mirror of https://github.com/google/gemma.cpp.git
Major revamp #2 of Prefill: fix token order, parallel for multi-query
- Allocate only the required KV caches and activation batch size
- Add flags for batch sizes
- Const-correct interface: Span of const int.
- Also clean up the KVCache arg to a span.
- Move kPrefillBatchSize into RuntimeConfig and remove related global constants.

PiperOrigin-RevId: 655893197
This commit is contained in:
parent c1f243c351
commit aaf51898b6
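At a glance, the interface change: KVCache::Create now takes the prefill token batch size, and the batched API takes const-correct spans. A minimal sketch reconstructed from the hunks below (compiles only inside the gemma.cpp tree; the helper function is illustrative, not part of the commit):

```cpp
#include <stddef.h>

#include "gemma/gemma.h"     // PromptTokens, MultiplePromptsTokens, KVCaches
#include "gemma/kv_cache.h"  // KVCache

// Was: KVCache::Create(model_type), sized by a compile-time kPrefillBatchSize.
// Now the prefill token batch size is an explicit runtime argument.
gcpp::KVCache MakeCache(gcpp::Model model_type, size_t prefill_tbatch_size) {
  return gcpp::KVCache::Create(model_type, prefill_tbatch_size);
}
```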
@@ -224,6 +224,7 @@ cc_library(
         ":common",
         ":cross_entropy",
         ":gemma_lib",
+        ":kv_cache",
         # Placeholder for internal dep, do not remove.,
         "@benchmark//:benchmark",
         "//compression:compress",
@@ -52,7 +52,7 @@ TEST(OptimizeTest, GradientDescent) {
       CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight);
   ByteStorageT backward =
       CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight);
-  KVCache kv_cache = KVCache::Create(info.model);
+  KVCache kv_cache = KVCache::Create(info.model, /*prefill_tbatch_size=*/16);

   Gemma gemma(GemmaTokenizer(), info, pool);
@@ -128,7 +128,8 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
     size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
     std::vector<int> prompt_slice(prompt.begin() + pos,
                                   prompt.begin() + pos + num_tokens);
-    KVCache kv_cache = KVCache::Create(env.Info().model);
+    KVCache kv_cache = KVCache::Create(
+        env.Info().model, env.MutableInferenceArgs().prefill_tbatch_size);
     float entropy = ComputeCrossEntropy(
         *env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
     total_entropy += entropy;
@@ -34,9 +34,9 @@
 #include "evals/cross_entropy.h"
 #include "gemma/common.h"  // StringFromType
 #include "gemma/gemma.h"
+#include "gemma/kv_cache.h"
 #include "util/app.h"
 #include "util/args.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"
@@ -76,10 +76,10 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
     fprintf(stderr, "Loading model...\n");
   model_ = AllocateGemma(loader_, pool_);

-  kv_caches_.reserve(kBatchedQueryBatchSize);
-  for (int i = 0; i < kBatchedQueryBatchSize; ++i) {
-    kv_caches_.push_back(new KVCache(KVCache::Create(model_->Info().model)));
-  }
+  // Only allocate one for starters because GenerateBatch might not be called.
+  kv_caches_.resize(1);
+  kv_caches_[0] =
+      KVCache::Create(model_->Info().model, inference.prefill_tbatch_size);
  }

  InitGenerator(inference_args_, gen_);
@@ -132,7 +132,7 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
   }
   gcpp::TimingInfo timing_info;
   runtime_config_.batch_stream_token = batch_stream_token;
-  model_->Generate(runtime_config_, tokens, /*start_pos=*/0, *kv_caches_[0],
+  model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
                    timing_info);
   if (app_.verbosity >= 1) {
     LogSpeedStats(time_start, total_tokens);
@@ -141,8 +141,10 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
 }

 std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
-    const hwy::Span<const hwy::Span<int>>& prompts) {
-  std::vector<std::pair<std::string, size_t>> res(prompts.size());
+    const MultiplePromptsTokens& prompts) {
+  const size_t num_queries = prompts.size();
+  HWY_ASSERT(num_queries != 0);
+  std::vector<std::pair<std::string, size_t>> res(num_queries);
   std::fill(res.begin(), res.end(), std::make_pair("", 0));
   size_t total_tokens = 0;
@@ -162,14 +164,29 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
     return true;
   };
   if (app_.verbosity >= 2) {
-    std::cout << inference_args_.max_tokens << " "
-              << inference_args_.max_generated_tokens << " "
-              << inference_args_.temperature;
+    fprintf(stderr,
+            "Max tok: %zu max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
+            inference_args_.max_tokens, inference_args_.max_generated_tokens,
+            inference_args_.temperature, inference_args_.prefill_tbatch_size,
+            inference_args_.decode_qbatch_size);
   }

+  // Ensure we have one KVCache per query.
+  if (kv_caches_.size() < num_queries) {
+    kv_caches_.resize(num_queries);
+  }
+  for (size_t i = 1; i < num_queries; ++i) {
+    if (kv_caches_[i].seq_len == 0) {
+      kv_caches_[i] = KVCache::Create(model_->Info().model,
+                                      inference_args_.prefill_tbatch_size);
+    }
+  }
+
   gcpp::TimingInfo timing_info;
   runtime_config_.batch_stream_token = batch_stream_token;
-  model_->GenerateBatch(runtime_config_, prompts, /*start_pos=*/0, kv_caches_,
-                        timing_info);
+  inference_args_.CopyTo(runtime_config_);
+  model_->GenerateBatch(runtime_config_, prompts, /*start_pos=*/0,
+                        KVCaches(&kv_caches_[0], num_queries), timing_info);
   if (app_.verbosity >= 1) {
     LogSpeedStats(time_start, total_tokens);
   }
@@ -191,13 +208,12 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
     prompts.push_back(WrapAndTokenize(model_->Tokenizer(), model_->Info(),
                                       /*pos=*/0, mutable_prompt));
   }
-  std::vector<hwy::Span<int>> prompt_vector;
+  std::vector<PromptTokens> prompt_vector;
   prompt_vector.reserve(prompts.size());
   for (auto& prompt : prompts) {
-    prompt_vector.push_back(hwy::Span<int>(prompt.data(), prompt.size()));
+    prompt_vector.push_back(PromptTokens(prompt.data(), prompt.size()));
   }
-  hwy::Span<const hwy::Span<int>> prompt_span = hwy::Span<const hwy::Span<int>>(
-      prompt_vector.data(), prompt_vector.size());
+  MultiplePromptsTokens prompt_span(prompt_vector.data(), prompt_vector.size());
   return BatchQueryModel2(prompt_span);
 }
@@ -226,8 +242,8 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   if (app.verbosity >= 2) {
     time_t now = time(nullptr);
     char* dt = ctime(&now);  // NOLINT
+    // TODO: replace hardware_concurrency with detected topology.
     std::cout << "Date & Time : " << dt
-              << "Prefill Token Batch Size : " << kPrefillBatchSize << "\n"
               << "Hardware concurrency : "
               << std::thread::hardware_concurrency() << "\n"
               << "Instruction set : "
@@ -69,7 +69,7 @@ class GemmaEnv {
   // the number of tokens that were generated.
   std::pair<std::string, size_t> QueryModel(const std::vector<int>& tokens);
   std::vector<std::pair<std::string, size_t>> BatchQueryModel2(
-      const hwy::Span<const hwy::Span<int>>& prompts);
+      const MultiplePromptsTokens& prompts);
   // Adds turn structure to input, tokenizes and calls the above overload.
   std::pair<std::string, size_t> QueryModel(std::string& input);
   std::vector<std::pair<std::string, size_t>> BatchQueryModel(
@@ -88,7 +88,7 @@ class GemmaEnv {
   const ModelInfo& Info() const { return loader_.Info(); }
   InferenceArgs& MutableInferenceArgs() { return inference_args_; }
   std::mt19937& MutableGen() { return gen_; }
-  KVCache& MutableKVCache() { return *kv_caches_[0]; }
+  KVCache& MutableKVCache() { return kv_caches_[0]; }

  private:
   // Arguments to the model loader: file locations, etc.
@@ -103,8 +103,8 @@ class GemmaEnv {
   std::mt19937 gen_;
   // The model to run inference on.
   std::unique_ptr<Gemma> model_;
-  // The KV cache to use for inference.
-  std::vector<KVCache*> kv_caches_;
+  // KV caches, same number as query batch.
+  std::vector<KVCache> kv_caches_;
   RuntimeConfig runtime_config_;
 };
@@ -17,14 +17,12 @@

 #include <stdio.h>

-#include <memory>
 #include <string>
 #include <vector>

 #include "evals/benchmark_helper.h"
 #include "gemma/common.h"
 #include "gemma/tokenizer.h"
-#include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/tests/hwy_gtest.h"

 // This test can be run manually with the downloaded gemma weights.
@@ -75,21 +73,17 @@ class GemmaTest : public ::testing::Test {
         replies.push_back(response);
       }
     } else {  // Not Gemma-2 27B. Do not use turn structure.
-      std::vector<std::unique_ptr<std::vector<int>>> prompts;
-      prompts.reserve(inputs.size());
-      for (auto input_string : inputs) {
-        std::string mutable_input_string = input_string;
-        prompts.push_back(std::make_unique<std::vector<int>>(
-            s_env->TokenizeAndPrependBOS(input_string)));
+      std::vector<std::vector<int>> prompts_vector;
+      prompts_vector.reserve(inputs.size());
+      for (const auto& input_string : inputs) {
+        prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
       }
-      std::vector<hwy::Span<int>> prompt_vector;
-      for (auto& prompt : prompts) {
-        prompt_vector.push_back(hwy::Span<int>(prompt->data(), prompt->size()));
+      std::vector<PromptTokens> prompt_spans;
+      for (const auto& prompt : prompts_vector) {
+        prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
       }
-      hwy::Span<const hwy::Span<int>> prompt_span =
-          hwy::Span<const hwy::Span<int>>(prompt_vector.data(),
-                                          prompt_vector.size());
-      for (auto [response, n] : s_env->BatchQueryModel2(prompt_span)) {
+      MultiplePromptsTokens prompts(prompt_spans.data(), prompt_spans.size());
+      for (auto [response, n] : s_env->BatchQueryModel2(prompts)) {
         replies.push_back(response);
       }
     }
@@ -121,18 +115,20 @@ class GemmaTest : public ::testing::Test {
   }
 };

-TEST_F(GemmaTest, Geography) {
+TEST_F(GemmaTest, GeographyBatched) {
+  s_env->MutableInferenceArgs().decode_qbatch_size = 3;
+  // 6 are enough to test batching and the loop.
   static const char* kQA[][2] = {
       {"What is the capital of Hungary?", "Budapest"},
       {"What is the capital of Australia?", "Canberra"},
       {"What is the capital of Denmark?", "Copenhagen"},
+      {"Ljubljana is the capital of which country?", "Slovenia"},
+      {"Is Chicago a country?", "not"},
       {"How many states does the US have?", "50"},
       {"What is the Pacific?", "ocean"},
   };
   static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
-  TestQuestions(kQA, kNum, /*batch=*/false);
-  static const char* kQA_single_question[][2] = {
-      {"What is the capital of Australia?", "Canberra"},
-  };
-  TestQuestions(kQA_single_question, 1, /*batch=*/true);
+  TestQuestions(kQA, HWY_MIN(kNum, 3), /*batch=*/false);
+  TestQuestions(kQA, 1, /*batch=*/true);
+  TestQuestions(kQA, kNum, /*batch=*/true);
 }
@@ -31,6 +31,7 @@

 int main(int argc, char** argv) {
   gcpp::LoaderArgs loader(argc, argv);
+  gcpp::InferenceArgs inference(argc, argv);
   if (gcpp::HasHelp(argc, argv)) {
     loader.Help();
     return 0;
@@ -42,7 +43,8 @@ int main(int argc, char** argv) {
   // Instantiate model and KV Cache
   hwy::ThreadPool pool(gcpp::AppArgs::GetSupportedThreadCount());
   gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
-  gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.Info().model);
+  gcpp::KVCache kv_cache =
+      gcpp::KVCache::Create(loader.Info().model, inference.prefill_tbatch_size);
   size_t pos = 0;  // KV Cache position

   // Initialize random number generator
@@ -36,11 +36,6 @@ ByteStorageT AllocateSizeof() {
   return hwy::AllocateAligned<uint8_t>(sizeof(T));
 }

-// Relatively small so that we can also parallelize non-Matmul work. There is
-// one outer thread per batch, each with --num_threads / batches inner threads.
-constexpr size_t kPrefillBatchSize = 64;
-constexpr size_t kBatchedQueryBatchSize = 16;
-
 // Model variants: see configs.h for details. When adding a new one, also
 // update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc.
 enum class Model {
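The deleted constants become runtime values rather than disappearing. Note the defaults: the RuntimeConfig member added later in this diff defaults prefill_tbatch_size to 32, while the --prefill_tbatch flag default stays at 64, matching the old constant. A sketch (the function name is illustrative):

```cpp
#include "gemma/gemma.h"  // RuntimeConfig fields added by this commit

void ConfigureBatching(gcpp::RuntimeConfig& runtime_config) {
  runtime_config.prefill_tbatch_size = 64;  // replaces kPrefillBatchSize
  runtime_config.decode_qbatch_size = 16;   // replaces kBatchedQueryBatchSize
}
```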
@@ -73,10 +73,10 @@ template <class TConfig>
 HWY_NOINLINE void GriffinRecurrent(
     size_t batch_start, size_t num_tokens, size_t num_queries, size_t layer,
     Activations& activations, const CompressedLayer<TConfig>* layer_weights,
-    const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool) {
+    const KVCaches& kv_caches, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Griffin");
   HWY_ASSERT(num_queries == 1);  // TODO: add batch query support for Griffin.
-  KVCache& kv_cache = *kv_caches[0];
+  KVCache& kv_cache = kv_caches[0];
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
   static constexpr size_t kModelDim = TConfig::kModelDim;
@@ -208,7 +208,7 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
                                  size_t num_queries, size_t layer,
                                  Activations& activations,
                                  const CompressedLayer<TConfig>* layer_weights,
-                                 const std::vector<KVCache*>& kv_caches,
+                                 const KVCaches& kv_caches,
                                  hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Attention");
   HWY_DASSERT(interleaved_start % num_queries == 0);
@@ -221,6 +221,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
   constexpr size_t kKVHeads = TConfig::kKVHeads;
   constexpr size_t kSeqLen = TConfig::kSeqLen;
   GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale<TConfig>();

+  HWY_ASSERT(num_queries <= kv_caches.size());
+  const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
+
   // Multi-Head Attention a.k.a. "use_qkv_einsum".
   constexpr bool kIsMHA = Activations::IsMHA<TConfig>();
   static_assert(!kIsMHA || TConfig::kInterleaveQKV);  // MHA => interleaved
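Why `div_seq_len` instead of `%`: `seq_len` is now a runtime value (kSeqLen + prefill_tbatch_size), so the compiler can no longer strength-reduce the modulo. hwy::Divisor precomputes the reciprocal once; Remainder() is then a few multiplies per call in the hot attention loops. A standalone sketch using the same two calls as the hunk above (values are illustrative):

```cpp
#include <cstdint>
#include <cstdio>

#include "hwy/base.h"  // hwy::Divisor

int main() {
  // E.g. kSeqLen + prefill_tbatch_size, known only at runtime.
  const uint32_t seq_len = 4096 + 64;
  const hwy::Divisor div_seq_len(seq_len);
  for (uint32_t pos : {0u, 4159u, 4160u, 9000u}) {
    // Same result as pos % seq_len, but cheaper than a hardware divide.
    std::printf("pos %u -> cache_pos %u\n", pos, div_seq_len.Remainder(pos));
  }
  return 0;
}
```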
@@ -245,9 +249,9 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
         const float* x = activations.pre_att_rms_out.Batch(interleaved_idx);
         const size_t query_idx = interleaved_idx % num_queries;
         const size_t batch_idx = interleaved_idx / num_queries;
-        KVCache& kv_cache = *kv_caches[query_idx];
+        KVCache& kv_cache = kv_caches[query_idx];
         const size_t pos = batch_start + batch_idx;
-        const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
+        const size_t cache_pos = div_seq_len.Remainder(pos);
         const size_t kv_offset =
             cache_pos * kCachePosSize + layer * kCacheLayerSize;
         float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
@@ -268,10 +272,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
         const size_t query_idx = interleaved_idx % num_queries;
         const size_t batch_idx = interleaved_idx / num_queries;
         const size_t pos = batch_start + batch_idx;
-        const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
+        const size_t cache_pos = div_seq_len.Remainder(pos);
         const size_t kv_offset = cache_pos * kCachePosSize +
                                  layer * kCacheLayerSize + head * kQKVDim * 2;
-        KVCache& kv_cache = *kv_caches[query_idx];
+        KVCache& kv_cache = kv_caches[query_idx];
         float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
         if constexpr (kIsMHA) {
           // For MHA, copy KV into the KV cache from scratch space (see above).
@@ -297,7 +301,7 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
         const size_t query_idx = interleaved_idx % num_queries;
         const size_t batch_idx = interleaved_idx / num_queries;
         const size_t head_offset = (head / kHeadGroups) * kQKVDim * 2;
-        KVCache& kv_cache = *kv_caches[query_idx];
+        KVCache& kv_cache = kv_caches[query_idx];
         float* HWY_RESTRICT q =
             activations.q.Batch(interleaved_idx) + head * kQStride;
@@ -314,10 +318,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
         const size_t start_pos =
             pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
         for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
-          const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+          const size_t cache_pos = div_seq_len.Remainder(pos2);
           const size_t kv_offset =
               cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
-          const float* HWY_RESTRICT k = kv_cache.kv_cache.get() + kv_offset;
+          const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
           const float score = Dot(q, k, kQKVDim);
           head_att[pos2 % kSeqLen] = score;
         }
@@ -337,7 +341,7 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
             activations.att_out.Batch(interleaved_idx) + head * kQKVDim;
         hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
         for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
-          const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+          const size_t cache_pos = div_seq_len.Remainder(pos2);
           const size_t kv_offset =
               cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
           float* HWY_RESTRICT v =
@@ -383,8 +387,7 @@ HWY_NOINLINE void Attention(LayerAttentionType type, size_t interleaved_start,
                             size_t num_tokens, size_t num_queries, size_t layer,
                             Activations& activations,
                             const CompressedLayer<TConfig>* layer_weights,
-                            const std::vector<KVCache*>& kv_caches,
-                            hwy::ThreadPool& pool) {
+                            const KVCaches& kv_caches, hwy::ThreadPool& pool) {
   if (type == LayerAttentionType::kGemma) {
     GemmaAttention<TConfig>(interleaved_start, num_tokens, num_queries, layer,
                             activations, layer_weights, kv_caches, pool);
@@ -458,12 +461,13 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
                       output_bias, pool);
 }

-// TODO: pass Activations.x instead of Activations.
-// `pos` is for the entire batch and does not include `batch_idx`.
 // `batch_idx` indicates which row of `x` to write to.
+// `pos` is the *token*'s position, not the start of the batch, because this is
+// called for batches of tokens in prefill, but batches of queries in decode.
 template <class TConfig>
 HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
                              const CompressedWeights<TConfig>& weights,
-                             Activations& activations) {
+                             RowVectorBatch<float>& x) {
   constexpr size_t kModelDim = TConfig::kModelDim;
   GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
       EmbeddingScaling<TConfig>();
@@ -472,11 +476,10 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
   HWY_DASSERT(token < TConfig::kVocabSize);

   Decompress(weights.embedder_input_embedding, token * kModelDim,
-             activations.x.Batch(batch_idx), kModelDim);
-  MulByConst(kEmbScaling, activations.x.Batch(batch_idx), kModelDim);
+             x.Batch(batch_idx), kModelDim);
+  MulByConst(kEmbScaling, x.Batch(batch_idx), kModelDim);
   if constexpr (TConfig::kAbsolutePE) {
-    AddAbsolutePositionalEmbeddings(activations.x.Batch(batch_idx), kModelDim,
-                                    pos + batch_idx);
+    AddAbsolutePositionalEmbeddings(x.Batch(batch_idx), kModelDim, pos);
   };
 }
@@ -501,7 +504,7 @@ template <class TConfig>
 HWY_NOINLINE void TransformerLayer(
     size_t num_tokens, size_t num_queries, size_t pos, size_t layer,
     const CompressedLayer<TConfig>* layer_weights, Activations& activations,
-    const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool) {
+    const KVCaches& kv_caches, hwy::ThreadPool& pool) {
   constexpr size_t kModelDim = TConfig::kModelDim;
   const size_t num_interleaved = num_tokens * num_queries;
   auto type = TConfig::kLayerConfig[layer];
@@ -536,116 +539,220 @@ HWY_NOINLINE void TransformerLayer(
                               /*is_attention=*/false);
 }

-// For prefill, we have two-level parallelism:
-// - Outer: input tokens are split into batches, each of which is one task
-//   processed by a worker in `outer_pool_`, which includes the main thread
-//   because it is the one that calls `Prefill`.
+// Batches are important for amortizing loading weights over multiple tokens.
+// This is possible in prefill because we know all tokens beforehand, whereas
+// decode depends on the previous output token. However, each prefill batch of a
+// query requires that preceding batches already wrote to the KV cache, hence we
+// sequentially loop over token batches. We can reduce the number of iterations
+// by increasing the batch size, but this also increases arithmetic intensity,
+// and so we are eventually compute-limited. The tensor parallelism (number of
+// threads collaborating on MatMul) is also limited by the CPU topology:
+// fork/join barriers are slow(er) when some threads reside in a different NUMA
+// node. To allow more threads to help, we also support parallelizing over
+// queries in case GenerateBatch was called.
+//
+// Thus we have two-level parallelism:
+// - Outer: handles one 'qbatch' of entire queries. The set of outer workers
+//   includes the main thread because it is the one that calls `Prefill`, and is
+//   determined by the number of 'clusters' (shared L3 caches or sockets).
 // - Inner: each `outer` worker passes `inner_pools_[outer]` to
-//   `TransformerLayer` for tensor-level parallelism.
+//   `TransformerLayer` for tensor-level parallelism, and processes
+//   `tbatch_size` tokens from a single query at a time.
 //
-// This class holds the thread pools and activations, recreated for each query.
-//
-// It is safe to parallelize batches because we write to KVCache at
-// `pos % kSeqLen`, which is far greater than the number of workers.
-// Note however that this currently leads to nondeterministic results because
-// the RNG is invoked in different order.
+// This class holds the thread pools and one activation per outer worker. It is
+// NOT reused across calls to GenerateSingle/GenerateBatch so that we can adapt
+// to their num_queries.
 class PrefillState {
- public:
-  explicit PrefillState(hwy::ThreadPool& main_pool) : main_pool_(&main_pool) {}
-
-  ~PrefillState() { DeleteInnerPools(); }
-
-  // Called before each query. Recreates thread pools, which has the advantage
-  // of tailoring the parallelism to the prompt length.
-  template <class TConfig>
-  void Init(size_t prefill_size) {
-    // Would be zero for single-token prompts (prefill_size == num_tokens - 1).
-    num_batches_ =
-        HWY_MAX(size_t{1}, hwy::DivCeil(prefill_size, kPrefillBatchSize));
-    // More than num_batches_ would waste workers on idling in the outer Run;
-    // more than NumWorkers() would exceed the global --num_threads.
-    const size_t outer_workers =
-        HWY_MIN(num_batches_, main_pool_->NumWorkers());
-    HWY_ASSERT(outer_workers != 0);  // Otherwise activations_ is empty.
-
-    // One activation per outer worker. Allocating in parallel saves 30 ms.
-    activations_.resize(outer_workers);
-    main_pool_->Run(0, outer_workers, [this](uint64_t task, size_t /*thread*/) {
-      activations_[task].Allocate<TConfig>(kPrefillBatchSize);
-    });
-
-    DeleteInnerPools();
-
-    // If we'd create just one inner pool with all the workers, skip the cost of
-    // thread creation and pinning (about 60 ms) by reusing the main pool.
-    if (outer_workers <= 1) {
-      // Still allocate a dummy pool to simplify Prefill().
-      outer_pool_ = std::make_unique<hwy::ThreadPool>(1);
-      inner_pools_.push_back(main_pool_);
-      return;
-    }
+  // TODO: move helper functions, also those in app.h, to a threading header
+  using LPS = hwy::LogicalProcessorSet;
+  LPS Intersection(const LPS& big, const LPS& small) {
+    LPS both;
+    // Reduce expected work by iterating over the smaller set.
+    small.Foreach([big, &both](size_t idx) {
+      if (big.Get(idx)) both.Set(idx);
+    });
+    return both;
+  }
+
+  std::vector<size_t> CoresInLPS(const LPS& cluster) {
+    std::vector<size_t> cores;
+    cores.reserve(cluster.Count());
+    cluster.Foreach([&cores](size_t idx) { cores.push_back(idx); });
+    return cores;
+  }
+
+  // For each cluster (shared L3 cache), a bitset of cores.
+  using CoresPerCluster = std::vector<LPS>;
+
+  // Returns empty if detection failed.
+  CoresPerCluster DetectClusters() {
+    CoresPerCluster clusters;
+    // Which processors are not disabled via OS, taskset, or numactl.
+    LPS enabled;
+    // If we don't know, better to use just a single inner pool rather than risk
+    // oversubscribing to enabled cores.
+    if (!GetThreadAffinity(enabled)) return clusters;
+
+    hwy::Topology topology;
+    if (topology.packages.empty()) return clusters;
+
+    // For each cluster = outer, the cores that will be used for an inner pool.
+    CoresPerCluster inner_lps;
+    for (const hwy::Topology::Package& package : topology.packages) {
+      for (const hwy::Topology::Cluster& cluster : package.clusters) {
+        // Only use enabled cores, and only add if not empty.
+        const LPS lps = Intersection(enabled, cluster.lps);
+        if (lps.Any()) clusters.push_back(lps);
+      }
+    }
+
+    // Sort by descending number of enabled cores, so that we preferentially
+    // use the largest clusters.
+    std::sort(clusters.begin(), clusters.end(),
+              [](const LPS& a, const LPS& b) { return a.Count() > b.Count(); });
+
+    return clusters;
+  }
+
+  // Returns false if the main pool should be reused instead.
+  bool AssignInnerPoolsToClusters(const size_t num_queries) {
+    HWY_ASSERT(num_queries != 0);
+
+    CoresPerCluster inner_lps = DetectClusters();
+    // If we have more outer workers than queries, discard the excess.
+    if (inner_lps.size() > num_queries) inner_lps.resize(num_queries);
+    // If we're not going to create multiple pools, avoid the overhead of
+    // re-pinning (60 ms) and reuse the main pool.
+    if (inner_lps.size() <= 1) return false;
+
+    // Before creating new threads, stop the old ones from spinning. Caller is
+    // responsible for undoing this by calling `ResumeMainSpinning`.
+    main_pool_->SetWaitMode(hwy::PoolWaitMode::kBlock);
-    outer_pool_ = std::make_unique<hwy::ThreadPool>(outer_workers);
+    outer_pool_ = std::make_unique<hwy::ThreadPool>(inner_lps.size());
+    outer_pool_->SetWaitMode(hwy::PoolWaitMode::kSpin);

-    // Assign up to `max_workers` to inner pools. Each inner pool creates
-    // `workers_per_outer - 1` threads in addition to its 'main' thread, which
-    // is the one calling `inner_pools[outer]->Run`, i.e., `outer`. In total,
-    // `outer_workers * (max_workers / outer_workers)` workers are used.
-    const size_t workers_per_outer = main_pool_->NumWorkers() / outer_workers;
-    for (size_t outer = 0; outer < outer_workers; ++outer) {
-      inner_pools_.push_back(new hwy::ThreadPool(workers_per_outer));
+    HWY_ASSERT(inner_pools_.empty());
+    for (const LPS& inner : inner_lps) {
+      inner_pools_.push_back(new hwy::ThreadPool(inner.Count()));
+      inner_pools_.back()->SetWaitMode(hwy::PoolWaitMode::kSpin);
     }

-    PinThreads(outer_workers, workers_per_outer);
+    // For each inner pool, pin their threads AND the associated outer thread
+    // to the enabled cores in the cluster.
+    outer_pool_->Run(
+        0, inner_lps.size(),
+        [this, &inner_lps](uint64_t outer, size_t outer_thread) {
+          HWY_ASSERT(outer == outer_thread);  // each outer has one task
+          const std::vector<size_t> cores = CoresInLPS(inner_lps[outer]);
+
+          inner_pools_[outer]->Run(
+              0, cores.size(), [&cores](uint64_t task, size_t thread) {
+                HWY_ASSERT(task == thread);  // each inner has one task
+                hwy::PinThreadToLogicalProcessor(cores[task]);
+              });
+        });
+
+    return true;
   }

-  // `tokens` are from interleaved queries. (See InterleaveQueries() below.)
+  void ReuseMainPoolAsInner() {
+    // Still allocate an empty pool to simplify Prefill().
+    outer_pool_ = std::make_unique<hwy::ThreadPool>(1);
+
+    HWY_ASSERT(inner_pools_.empty());
+    inner_pools_.push_back(main_pool_);
+  }
+
+ public:
+  // Creates pools. AllocateActivations must still be called separately; it has
+  // a template argument.
+  PrefillState(hwy::ThreadPool& main_pool, size_t num_queries)
+      : main_pool_(&main_pool) {
+    PROFILER_ZONE("Init.Prefill.Ctor");
+    if (!AssignInnerPoolsToClusters(num_queries)) {
+      ReuseMainPoolAsInner();
+    }
+  }
+
+  ~PrefillState() {
+    for (hwy::ThreadPool* p : inner_pools_) {
+      if (p != main_pool_) delete p;
+    }
+  }
+
+  // `tbatch_size` is the number of tokens from one query to prefill at a time.
   template <class TConfig>
-  HWY_NOINLINE void Prefill(hwy::Span<const int> tokens, size_t num_queries,
-                            size_t pos,
+  void AllocateActivations(size_t num_queries, size_t tbatch_size) {
+    PROFILER_ZONE("Init.Prefill.AllocateActivations");
+
+    const size_t outer_workers = outer_pool_->NumWorkers();
+    HWY_ASSERT(outer_workers != 0);  // Otherwise activations_ is empty.
+
+    HWY_ASSERT(activations_.empty());  // only call once.
+    activations_.resize(outer_workers);
+
+    if (outer_workers == 1) {
+      activations_[0].Allocate<TConfig>(tbatch_size);
+    } else {
+      // Allocating in parallel can save 30 ms.
+      main_pool_->Run(0, outer_workers,
+                      [this, tbatch_size](uint64_t task, size_t /*thread*/) {
+                        activations_[task].Allocate<TConfig>(tbatch_size);
+                      });
+    }
+  }
+
+  template <class TConfig>
+  HWY_NOINLINE void Prefill(const MultiplePromptsTokens& prompts,
+                            const size_t prefill_per_query, const size_t pos,
+                            const size_t query_idx_start,
                             const CompressedWeights<TConfig>& weights,
                             const RuntimeConfig& runtime_config,
-                            const std::vector<KVCache*>& kv_caches) {
+                            const KVCaches& kv_caches) {
     PROFILER_ZONE("Gen.Prefill");
+    const size_t num_queries = prompts.size();
+    HWY_ASSERT(kv_caches.size() == num_queries);
+    const size_t max_tbatch_size = activations_[0].x.BatchSize();
+
+    HWY_ASSERT(activations_.size() == outer_pool_->NumWorkers());
+    HWY_ASSERT(inner_pools_.size() == outer_pool_->NumWorkers());
+
+    // For each query (parallel): an outer worker processes all its tokens.
+    // `qi` is relative to the batch, not the global query index.
     outer_pool_->Run(
-        0, num_batches_, [&](const uint64_t batch_num, size_t thread) HWY_ATTR {
-          const size_t batch_start = batch_num * kPrefillBatchSize;
-          const size_t batch_size =
-              HWY_MIN(kPrefillBatchSize, tokens.size() - batch_start);
-          HWY_DASSERT(batch_start % num_queries == 0);
-          HWY_DASSERT(batch_size % num_queries == 0);
-          const size_t pos_per_query = pos + batch_start / num_queries;
-          const size_t num_tokens = batch_size / num_queries;
+        0, num_queries, [&](const uint64_t qi, size_t qthread) HWY_ATTR {
+          Activations& activations = activations_[qthread];
+          hwy::ThreadPool& inner_pool = *inner_pools_[qthread];

-          // Negligible time compared to TransformerLayer.
-          for (size_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
-            EmbedToken<TConfig>(tokens[batch_start + batch_idx], batch_idx,
-                                pos_per_query, weights, activations_[thread]);
-          }
+          // Single query at a time, so pass a slice of the KV cache because
+          // GemmaAttention will only access the first.
+          const size_t kPrefillQueries = 1;
+          KVCaches prefill_kv_caches(&kv_caches[qi], kPrefillQueries);

-          for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
-            const auto* layer_weights = weights.GetLayer(layer);
-            TransformerLayer<TConfig>(
-                num_tokens, num_queries, pos_per_query, layer, layer_weights,
-                activations_[thread], kv_caches, *inner_pools_[thread]);
-          }
+          // For each batch of tokens in the query:
+          for (size_t tbatch_start = 0; tbatch_start < prefill_per_query;
+               tbatch_start += max_tbatch_size) {
+            // Fill activations.x (much faster than TransformerLayer).
+            const size_t tbatch_size =
+                HWY_MIN(max_tbatch_size, prefill_per_query - tbatch_start);
+            for (size_t ti = 0; ti < tbatch_size; ++ti) {
+              const int token = prompts[qi][tbatch_start + ti];
+              EmbedToken<TConfig>(token, ti, pos + ti, weights, activations.x);
+            }

-          // NOTE: we unconditionally call StreamToken, even if EOS.
-          for (size_t i = 0; i < batch_size; ++i) {
-            const size_t query_idx = i % num_queries;
-            const size_t batch_idx = i / num_queries;
-            runtime_config.StreamToken(query_idx, pos_per_query + batch_idx,
-                                       tokens[i], 0.0f);
-          }
+            // Transformer with one batch of tokens from a single query.
+            for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
+              const auto* layer_weights = weights.GetLayer(layer);
+              TransformerLayer<TConfig>(
+                  tbatch_size, kPrefillQueries, pos + tbatch_start, layer,
+                  layer_weights, activations, prefill_kv_caches, inner_pool);
+            }
+
+            // NOTE: we unconditionally call StreamToken, even if EOS.
+            for (size_t ti = 0; ti < tbatch_size; ++ti) {
+              const int token = prompts[qi][tbatch_start + ti];
+              runtime_config.StreamToken(query_idx_start + qi,
+                                         pos + tbatch_start + ti, token, 0.0f);
+            }
+          }  // for tbatch_start
         });
   }
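A toy illustration of the two-level structure described in the comment above: one outer worker per query, each fanning a token batch out over its own inner pool. This is not gemma.cpp code; it only assumes hwy::ThreadPool's constructor and Run(begin, end, closure) interface, both used by the hunk itself:

```cpp
#include <cstdio>

#include "hwy/contrib/thread_pool/thread_pool.h"

int main() {
  hwy::ThreadPool outer(2);  // one worker per query in the qbatch
  hwy::ThreadPool inner0(4), inner1(4);
  hwy::ThreadPool* inner[2] = {&inner0, &inner1};

  outer.Run(0, 2, [&](uint64_t qi, size_t outer_thread) {
    // Each outer worker walks its query's token batches sequentially (the KV
    // cache dependency), parallelizing within a batch on its own inner pool.
    inner[outer_thread]->Run(0, 4, [&](uint64_t task, size_t /*thread*/) {
      std::printf("query %d: inner task %d\n", int(qi), int(task));
    });
  });
  return 0;
}
```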
@@ -663,39 +770,15 @@ class PrefillState {
   }

  private:
-  // Pins each outer thread after their inner threads so they are likely to
-  // run on the same socket.
-  void PinThreads(size_t outer_workers, size_t workers_per_outer) {
-    outer_pool_->Run(
-        0, outer_workers,
-        [this, workers_per_outer](uint64_t outer, size_t outer_thread) {
-          HWY_ASSERT(outer == outer_thread);
-          // Pins inner *and* `outer` - the latter is the calling thread.
-          inner_pools_[outer]->Run(
-              0, workers_per_outer,
-              [outer, workers_per_outer](uint64_t task, size_t thread) {
-                HWY_ASSERT(task == thread);  // each worker has one task
-                const size_t lp = outer * workers_per_outer + task;
-                hwy::PinThreadToLogicalProcessor(lp);
-              });
-        });
-  }
-
-  void DeleteInnerPools() {
-    for (hwy::ThreadPool* p : inner_pools_) {
-      if (p != main_pool_) delete p;
-    }
-    inner_pools_.clear();
-  }
-
   hwy::ThreadPool* main_pool_;
   std::unique_ptr<hwy::ThreadPool> outer_pool_;  // always allocated
-  std::vector<Activations> activations_;  // size == outer_pool->NumWorkers()
-  // Either there is a single pointer equal to main_pool, or newly created pools
-  // that we own. The former case avoids thread creation overhead for prompts
-  // that fit in a single batch.
+  // Holds a single pointer equal to main_pool_, or new allocations; in either
+  // case, size() is equal to outer_pool_->NumWorkers(). The first case avoids
+  // allocation overhead for the common case of a single query.
   std::vector<hwy::ThreadPool*> inner_pools_;
-  size_t num_batches_ = 0;
+
+  // size() == outer_pool_->NumWorkers(); filled by AllocateActivations.
+  std::vector<Activations> activations_;
 };

 // `tokens` is length `num_tokens * num_queries`. In autoregressive decode,
@@ -705,8 +788,7 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
                               size_t num_queries, size_t pos,
                               const CompressedWeights<TConfig>& weights,
                               Activations& activations,
-                              const std::vector<KVCache*>& kv_caches,
-                              hwy::ThreadPool& pool,
+                              const KVCaches& kv_caches, hwy::ThreadPool& pool,
                               const LayersOutputFunc& layers_output) {
   const size_t num_interleaved = num_tokens * num_queries;
   if (layers_output) {
@@ -718,7 +800,7 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
   constexpr size_t kModelDim = TConfig::kModelDim;
   for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
     EmbedToken<TConfig>(tokens[token_idx], token_idx, pos, weights,
-                        activations);
+                        activations.x);
   }

   for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
@@ -781,10 +863,10 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,

 // Returns interleaved tokens: one from each query, followed by the second from
 // all queries, with EOS padding.
-static std::vector<int> InterleaveQueries(
-    const hwy::Span<const hwy::Span<int>>& queries,
-    const RuntimeConfig& runtime_config, size_t& min_prompt_size,
-    size_t& max_prompt_size) {
+static std::vector<int> InterleaveQueries(const MultiplePromptsTokens& queries,
+                                          const RuntimeConfig& runtime_config,
+                                          size_t& min_prompt_size,
+                                          size_t& max_prompt_size) {
   const size_t num_queries = queries.size();
   min_prompt_size = hwy::LimitsMax<size_t>();
   max_prompt_size = 0;
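The interleaving itself, per the comment ("one from each query, followed by the second from all queries, with EOS padding"), as a toy standalone version with made-up token values:

```cpp
#include <cstdio>
#include <vector>

int main() {
  const int kEOS = 1;  // placeholder EOS id
  std::vector<std::vector<int>> queries = {{10, 11, 12}, {20, 21}};
  size_t max_len = 0;
  for (const auto& q : queries) {
    if (q.size() > max_len) max_len = q.size();
  }

  std::vector<int> interleaved;
  for (size_t pos = 0; pos < max_len; ++pos) {
    for (const auto& q : queries) {
      // Shorter queries are padded with EOS.
      interleaved.push_back(pos < q.size() ? q[pos] : kEOS);
    }
  }
  // Prints: 10 20 11 21 12 1 -- one token from each query per step.
  for (int t : interleaved) std::printf("%d ", t);
  std::printf("\n");
  return 0;
}
```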
@@ -829,28 +911,34 @@ class TokenStreamer {

  private:
   const RuntimeConfig& runtime_config_;
-  // BitSet4096 divides the arg by 64, so ensure it is at least 64.
-  hwy::BitSet4096<HWY_MAX(64, kBatchedQueryBatchSize)> is_eos_;
+  hwy::BitSet4096<> is_eos_;
 };

-// Generates one token per query in the batch.
+// Generates one token for each query in `prompts`, which is one qbatch whose
+// size is at most the `batch_size` passed to `activations.Allocate`.
+//
-// pos indexes the KV cache. In the first turn of a chat, pos = 0, and it
+// `pos` indexes the KV cache. In the first turn of a chat, pos = 0, and it
 // continues to increase by one for each prefilled/generated token per query.
-// query_idx_start is the first query index in the batch.
-template <class TConfig, size_t kQueryBatchSize>
+//
+// `query_idx_start` is the query_idx of the first query in the batch, so that
+// `StreamFunc` gets the global query index, not relative to the batch.
+//
+// `kv_caches` is for the batch, size must match `prompts`.
+template <class TConfig>
 void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
                const RuntimeConfig& runtime_config,
-               const hwy::Span<const hwy::Span<int>>& prompts, const size_t pos,
-               const size_t query_idx_start,
-               const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
-               TimingInfo& timing_info) {
+               const MultiplePromptsTokens& prompts, const size_t pos,
+               const size_t query_idx_start, const KVCaches& kv_caches,
+               hwy::ThreadPool& pool, TimingInfo& timing_info) {
   constexpr size_t kVocabSize = TConfig::kVocabSize;
   const CompressedWeights<TConfig>& weights =
       *reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());

   const size_t num_queries = prompts.size();
-  HWY_DASSERT(num_queries <= kQueryBatchSize);
+  HWY_ASSERT(num_queries <= 4096);  // TokenStreamer uses BitSet4096.
+  HWY_ASSERT(num_queries <= activations.x.BatchSize());
+  HWY_ASSERT(kv_caches.size() == num_queries);

   size_t min_prompt_size, max_prompt_size;
   const std::vector<int> prompt = InterleaveQueries(
       prompts, runtime_config, min_prompt_size, max_prompt_size);
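The switch to `hwy::BitSet4096<>` removes the dependence on the deleted kBatchedQueryBatchSize: the default template argument supports up to 4096 queries, hence the new `HWY_ASSERT(num_queries <= 4096)`. A fragment showing the intended use (not a complete program; it assumes the same Get/Set interface that LogicalProcessorSet exhibits in the PrefillState hunk above):

```cpp
hwy::BitSet4096<> is_eos;  // one bit per query in the qbatch
// When streaming a token for query_idx:
if (token == eos_id) is_eos.Set(query_idx);
// Later iterations can skip queries that already finished:
const bool already_done = is_eos.Get(query_idx);
```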
@@ -877,28 +965,28 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
   // Prefill stops before min_prompt_size - 1 because the last prompt token is
   // the first input token for generation.
   const size_t prefill_per_query = min_prompt_size - 1;
-  const hwy::Span<const int> prefill_tokens(prompt.data(),
-                                            prefill_per_query * num_queries);
-  PrefillState prefill(pool);
-  prefill.Init<TConfig>(prefill_tokens.size());
-  const double prefill_start = hwy::platform::Now();
-  size_t interleaved_pos = pos * num_queries;
-  prefill.Prefill<TConfig>(prefill_tokens, num_queries, interleaved_pos,
-                           weights, runtime_config, kv_caches);
-  interleaved_pos += prefill_tokens.size();
-  timing_info.NotifyPrefill(prefill_tokens.size(), prefill_start);
+  double prefill_start;
+  {
+    PrefillState prefill(pool, num_queries);
+    prefill.AllocateActivations<TConfig>(num_queries,
+                                         runtime_config.prefill_tbatch_size);
+    prefill_start = hwy::platform::Now();
+    prefill.Prefill<TConfig>(prompts, prefill_per_query, pos, query_idx_start,
+                             weights, runtime_config, kv_caches);
+    timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start);
+    prefill.ResumeMainSpinning();
+  }

-  prefill.ResumeMainSpinning();
+  size_t interleaved_pos = (pos + prefill_per_query) * num_queries;

   // Storage for the last generated token from each query, passed to the next
   // Transformer() call.
   std::vector<int> gen_tokens(num_queries);

   // Stream the last prompt token from each query and fill gen_tokens.
-  hwy::CopyBytes(&prompt[prefill_tokens.size()], gen_tokens.data(),
-                 num_queries * sizeof(prompt[0]));
   TokenStreamer token_streamer(runtime_config);
   for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
+    gen_tokens[query_idx] = prompts[query_idx][prefill_per_query];
     (void)token_streamer(query_idx_start + query_idx, prefill_per_query,
                          gen_tokens[query_idx], 0.0f);
   }
@@ -940,42 +1028,49 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
   timing_info.NotifyGenerateDone(gen_start);
 }

-// TODO: prompt should also be span, not a vector.
 template <class TConfig>
-void GenerateSingleT(const ByteStorageT& weights_u8, Activations& activations,
+void GenerateSingleT(const ByteStorageT& weights_u8,
                      const RuntimeConfig& runtime_config,
-                     const std::vector<int>& prompt, size_t pos,
-                     KVCache& kv_cache, hwy::ThreadPool& pool,
-                     TimingInfo& timing_info) {
-  const hwy::Span<int> prompt_span(const_cast<int*>(prompt.data()),
-                                   prompt.size());
-  const hwy::Span<const hwy::Span<int>> prompts(&prompt_span, 1);
-  // TODO: also span of kv_cache, or batching inside KVCache?
-  std::vector<KVCache*> kv_caches = {&kv_cache};
-  const size_t query_idx_start = 0;
-  GenerateT<TConfig, /*kQueryBatchSize=*/1>(
-      weights_u8, activations, runtime_config, prompts, pos, query_idx_start,
-      kv_caches, pool, timing_info);
+                     const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
+                     hwy::ThreadPool& pool, TimingInfo& timing_info) {
+  const size_t num_queries = 1;
+  const size_t qbatch_start = 0;
+
+  Activations activations;
+  activations.Allocate<TConfig>(num_queries);
+
+  const MultiplePromptsTokens prompts(&prompt, num_queries);
+  const KVCaches kv_caches{&kv_cache, num_queries};
+
+  GenerateT<TConfig>(weights_u8, activations, runtime_config, prompts, pos,
+                     qbatch_start, kv_caches, pool, timing_info);
 }

 template <class TConfig>
-void GenerateBatchT(const ByteStorageT& weights_u8, Activations& activations,
+void GenerateBatchT(const ByteStorageT& weights_u8,
                     const RuntimeConfig& runtime_config,
-                    const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
-                    const std::vector<KVCache*>& kv_caches,
-                    hwy::ThreadPool& pool, TimingInfo& timing_info) {
-  // Disable query batching for Griffin models.
-  constexpr size_t kQueryBatchSize =
-      (TConfig::kGriffinLayers > 0) ? 1 : kBatchedQueryBatchSize;
-  for (size_t query_idx_start = 0; query_idx_start < prompts.size();
-       query_idx_start += kQueryBatchSize) {
-    const size_t num_queries =
-        std::min(prompts.size() - query_idx_start, kQueryBatchSize);
-    const hwy::Span<const hwy::Span<int>> query_batch(
-        prompts.data() + query_idx_start, num_queries);
-    GenerateT<TConfig, kQueryBatchSize>(weights_u8, activations, runtime_config,
-                                        query_batch, pos, query_idx_start,
-                                        kv_caches, pool, timing_info);
+                    const MultiplePromptsTokens& prompts, size_t pos,
+                    const KVCaches& kv_caches, hwy::ThreadPool& pool,
+                    TimingInfo& timing_info) {
+  HWY_ASSERT(prompts.size() == kv_caches.size());
+  // Griffin does not support query batching.
+  const size_t max_qbatch_size =
+      (TConfig::kGriffinLayers > 0) ? 1 : runtime_config.decode_qbatch_size;
+
+  Activations activations;
+  activations.Allocate<TConfig>(max_qbatch_size);
+
+  const size_t num_queries = prompts.size();
+  for (size_t qbatch_start = 0; qbatch_start < num_queries;
+       qbatch_start += max_qbatch_size) {
+    // Generate one batch of tokens from `qbatch_size` queries.
+    const size_t qbatch_size =
+        HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
+    const MultiplePromptsTokens qbatch_prompts(&prompts[qbatch_start],
+                                               qbatch_size);
+    const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
+    GenerateT<TConfig>(weights_u8, activations, runtime_config, qbatch_prompts,
+                       pos, qbatch_start, qbatch_kv, pool, timing_info);
   }
 }
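How a caller drives the batched path end to end, loosely following the updated gemma_test.cc earlier in this diff. A sketch: BatchGenerate is an illustrative helper (not part of the commit), and it assumes `model` came from CreateGemma and `runtime_config` was filled via InferenceArgs::CopyTo:

```cpp
#include <stddef.h>

#include <vector>

#include "gemma/gemma.h"

// Batch-generates over already-tokenized prompts, one KVCache per query, as
// the updated GenerateBatch interface above expects.
void BatchGenerate(gcpp::Gemma& model,
                   const gcpp::RuntimeConfig& runtime_config,
                   const std::vector<std::vector<int>>& token_vectors) {
  std::vector<gcpp::PromptTokens> spans;
  spans.reserve(token_vectors.size());
  for (const auto& v : token_vectors) spans.emplace_back(v.data(), v.size());
  const gcpp::MultiplePromptsTokens prompts(spans.data(), spans.size());

  // One KV cache per query; Create now takes the prefill token batch size.
  std::vector<gcpp::KVCache> caches;
  for (size_t i = 0; i < spans.size(); ++i) {
    caches.push_back(gcpp::KVCache::Create(
        model.Info().model, runtime_config.prefill_tbatch_size));
  }

  gcpp::TimingInfo timing_info;
  model.GenerateBatch(runtime_config, prompts, /*start_pos=*/0,
                      gcpp::KVCaches(&caches[0], caches.size()), timing_info);
}
```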
@@ -986,24 +1081,20 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
 // These are extern functions defined by instantiations/*.cc, which include this
 // 'header' after defining GEMMA_CONFIG, which is for function overloading.
 void GenerateSingle(  // NOLINT(misc-definitions-in-headers)
-    GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& activations,
-    const RuntimeConfig& runtime_config, const std::vector<int>& prompt,
-    size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-    TimingInfo& timing_info) {
+    GEMMA_CONFIG, const ByteStorageT& weights_u8,
+    const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos,
+    KVCache& kv_cache, hwy::ThreadPool& pool, TimingInfo& timing_info) {
   HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_CONFIG>)
-  (weights_u8, activations, runtime_config, prompt, pos, kv_cache, pool,
-   timing_info);
+  (weights_u8, runtime_config, prompt, pos, kv_cache, pool, timing_info);
 }

 void GenerateBatch(  // NOLINT(misc-definitions-in-headers)
-    GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& activations,
-    const RuntimeConfig& runtime_config,
-    const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
-    const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
+    GEMMA_CONFIG, const ByteStorageT& weights_u8,
+    const RuntimeConfig& runtime_config, const MultiplePromptsTokens& prompts,
+    size_t pos, const KVCaches& kv_caches, hwy::ThreadPool& pool,
     TimingInfo& timing_info) {
   HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_CONFIG>)
-  (weights_u8, activations, runtime_config, prompts, pos, kv_caches, pool,
-   timing_info);
+  (weights_u8, runtime_config, prompts, pos, kv_caches, pool, timing_info);
 }

 #endif  // HWY_ONCE
@@ -24,32 +24,19 @@
 #include <string.h>

 #include <utility>  // std::move
 #include <vector>

 #include "compression/io.h"  // Path
-#include "gemma/activations.h"
 #include "gemma/common.h"
 #include "gemma/weights.h"
 #include "hwy/aligned_allocator.h"  // Span
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"

 namespace gcpp {

-template <typename TConfig>
-struct AllocateActivations {
-  void operator()(Activations& decode) const {
-    // TODO: this is wasted if we only have single-batch queries. Instead
-    // re-allocate when the query batch size is actually > 1?
-    decode.Allocate<TConfig>(kBatchedQueryBatchSize);
-  }
-};
-
 Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
              const ModelInfo& info, hwy::ThreadPool& pool)
     : pool_(pool), tokenizer_(tokenizer_path), info_(info) {
   weights_u8_ = LoadCompressedWeights(weights, info.model, info.weight, pool);
-  CallForModelAndWeight<AllocateActivations>(info.model, info.weight, decode_);
 }

 Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
@@ -58,7 +45,6 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
   HWY_ASSERT(info.weight == Type::kF32);
   weights_u8_ =
       CallForModel<float, AllocateCompressedWeights>(info.model, pool);
-  CallForModelAndWeight<AllocateActivations>(info.model, info.weight, decode_);
 }

 Gemma::~Gemma() {
@@ -70,67 +56,64 @@ Gemma::~Gemma() {
 // we shard them across multiple translation units in instantiations/*.cc.
 // This declares the functions defined there. We use overloading because
 // explicit instantiations are still too slow to compile.
-#define GEMMA_DECLARE(CONFIGT, TWEIGHT)                                      \
-  extern void GenerateSingle(                                                \
-      CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, Activations& decode, \
-      const RuntimeConfig& runtime_config, const std::vector<int>& prompt,   \
-      size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool,                  \
-      TimingInfo& timing_info);                                              \
-  extern void GenerateBatch(                                                 \
-      CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, Activations& decode, \
-      const RuntimeConfig& runtime_config,                                   \
-      const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,            \
-      const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,         \
-      TimingInfo& timing_info);
+#define GEMMA_DECLARE(CONFIGT, TWEIGHT)                                        \
+  extern void GenerateSingle(CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
+                             const RuntimeConfig& runtime_config,              \
+                             const PromptTokens& prompt, size_t pos,           \
+                             KVCache& kv_cache, hwy::ThreadPool& pool,         \
+                             TimingInfo& timing_info);                         \
+  extern void GenerateBatch(CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8,  \
+                            const RuntimeConfig& runtime_config,               \
+                            const MultiplePromptsTokens& prompts, size_t pos,  \
+                            const KVCaches& kv_caches, hwy::ThreadPool& pool,  \
+                            TimingInfo& timing_info);
 GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE);

 // Adapters to select from the above overloads via CallForModelAndWeight.
 // TODO: gather all ByteStorageT into a type-erased model struct?
 template <class TConfig>
 struct GenerateSingleT {
-  void operator()(const ByteStorageT& weights_u8, Activations& decode,
+  void operator()(const ByteStorageT& weights_u8,
                   const RuntimeConfig& runtime_config,
-                  const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
+                  const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
                   hwy::ThreadPool& pool, TimingInfo& timing_info) const {
-    GenerateSingle(TConfig(), weights_u8, decode, runtime_config, prompt, pos,
-                   kv_cache, pool, timing_info);
+    GenerateSingle(TConfig(), weights_u8, runtime_config, prompt, pos, kv_cache,
+                   pool, timing_info);
   }
 };

 template <class TConfig>
 struct GenerateBatchT {
-  void operator()(const ByteStorageT& weights_u8, Activations& decode,
+  void operator()(const ByteStorageT& weights_u8,
                   const RuntimeConfig& runtime_config,
-                  const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
-                  const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
+                  const MultiplePromptsTokens& prompts, size_t pos,
+                  const KVCaches& kv_caches, hwy::ThreadPool& pool,
                   TimingInfo& timing_info) const {
-    GenerateBatch(TConfig(), weights_u8, decode, runtime_config, prompts, pos,
+    GenerateBatch(TConfig(), weights_u8, runtime_config, prompts, pos,
                   kv_caches, pool, timing_info);
   }
 };

 void Gemma::Generate(const RuntimeConfig& runtime_config,
-                     const std::vector<int>& prompt, size_t start_pos,
+                     const PromptTokens& prompt, size_t start_pos,
                      KVCache& kv_cache, TimingInfo& timing_info) {
   pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);

-  CallForModelAndWeight<GenerateSingleT>(
-      info_.model, info_.weight, weights_u8_, decode_, runtime_config, prompt,
-      start_pos, kv_cache, pool_, timing_info);
+  CallForModelAndWeight<GenerateSingleT>(info_.model, info_.weight, weights_u8_,
+                                         runtime_config, prompt, start_pos,
+                                         kv_cache, pool_, timing_info);

   pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }

 void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
-                          const hwy::Span<const hwy::Span<int>>& prompts,
-                          size_t start_pos,
-                          const std::vector<KVCache*>& kv_caches,
+                          const MultiplePromptsTokens& prompts,
+                          size_t start_pos, const KVCaches& kv_caches,
                           TimingInfo& timing_info) {
   pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);

-  CallForModelAndWeight<GenerateBatchT>(
-      info_.model, info_.weight, weights_u8_, decode_, runtime_config, prompts,
-      start_pos, kv_caches, pool_, timing_info);
+  CallForModelAndWeight<GenerateBatchT>(info_.model, info_.weight, weights_u8_,
+                                        runtime_config, prompts, start_pos,
+                                        kv_caches, pool_, timing_info);

   pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
@@ -30,7 +30,7 @@
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/timer.h"
 // IWYU pragma: end_exports
-#include "hwy/aligned_allocator.h"
+#include "hwy/aligned_allocator.h"  // Span
 #include "hwy/base.h"  // hwy::bfloat16_t

 namespace gcpp {
@@ -67,6 +67,13 @@ struct RuntimeConfig {

   size_t max_tokens;
   size_t max_generated_tokens;

+  // These defaults are overridden by InferenceArgs::CopyTo(*this):
+  // Max tokens per batch during prefill.
+  size_t prefill_tbatch_size = 32;
+  // Max queries per batch (one token from each) during decode.
+  size_t decode_qbatch_size = 16;
+
   float temperature;
   int verbosity;
   std::mt19937* gen;
@@ -105,6 +112,10 @@ struct TimingInfo {
   size_t tokens_generated;
 };

+using PromptTokens = hwy::Span<const int>;
+using MultiplePromptsTokens = hwy::Span<const PromptTokens>;
+using KVCaches = hwy::Span<KVCache>;
+
 class Gemma {
  public:
  Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
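These aliases are non-owning hwy::Span views, so the caller keeps the backing storage alive for the duration of the call. A fragment (token values are placeholders):

```cpp
std::vector<int> tokens = {2, 651, 6037};  // tokenizer output (placeholder)
gcpp::PromptTokens prompt(tokens.data(), tokens.size());

std::vector<gcpp::PromptTokens> views = {prompt};
gcpp::MultiplePromptsTokens prompts(views.data(), views.size());  // one query
```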
@@ -118,25 +129,20 @@ class Gemma {
   const ModelInfo& Info() const { return info_; }
   const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
   const ByteStorageT& Weights() const { return weights_u8_; }
-  const Activations& Decode() const { return decode_; }

-  void Generate(const RuntimeConfig& runtime_config,
-                const std::vector<int>& prompt, size_t start_pos,
-                KVCache& kv_cache, TimingInfo& timing_info);
+  void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
+                size_t start_pos, KVCache& kv_cache, TimingInfo& timing_info);

   void GenerateBatch(const RuntimeConfig& runtime_config,
-                     const hwy::Span<const hwy::Span<int>>& prompts,
-                     size_t start_pos, const std::vector<KVCache*>& kv_caches,
-                     TimingInfo& timing_info);
+                     const MultiplePromptsTokens& prompts, size_t start_pos,
+                     const KVCaches& kv_caches, TimingInfo& timing_info);

  private:
   hwy::ThreadPool& pool_;

   GemmaTokenizer tokenizer_;
-  // Type-erased so that this can be defined in the header, without requiring
-  // forwarding functions.
+  // Type-erased so that this can be defined in the header.
   ByteStorageT weights_u8_;
-  Activations decode_;
   ModelInfo info_;
 };
@@ -23,13 +23,16 @@ namespace gcpp {
 namespace {
 template <class TConfig>
 struct CreateKVCache {
-  KVCache operator()() const {
+  KVCache operator()(size_t prefill_tbatch_size) const {
     KVCache kv_cache = {};

     const size_t size_cache_pos = CachePosSize<TConfig>()();
     if (size_cache_pos != 0) {
-      const size_t seq_len = (TConfig::kSeqLen + kPrefillBatchSize);
-      kv_cache.kv_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
+      // Allocate more so that prefill can always access one batch, even if
+      // near the end of the sequence.
+      kv_cache.seq_len = TConfig::kSeqLen + prefill_tbatch_size;
+      kv_cache.kv_cache =
+          hwy::AllocateAligned<float>(kv_cache.seq_len * size_cache_pos);
     }

     // TODO(patrickms): Add query batching support for Griffin.
@@ -58,10 +61,13 @@ struct CreateKVCache {
 };
 }  // namespace

-KVCache KVCache::Create(Model model_type) {
+// prefill_tbatch_size is the maximum number of tokens from one query to
+// prefill at a time.
+KVCache KVCache::Create(Model model_type, size_t prefill_tbatch_size) {
   // TWeight=float is a placeholder and unused because CreateKVCache does not
   // use TConfig::Weight.
-  return CallForModel</*TWeight=*/float, CreateKVCache>(model_type);
+  return CallForModel</*TWeight=*/float, CreateKVCache>(model_type,
+                                                        prefill_tbatch_size);
 }

 }  // namespace gcpp
@@ -16,13 +16,17 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_

+#include <stddef.h>
+
 #include "gemma/common.h"  // Model
 #include "hwy/aligned_allocator.h"

 namespace gcpp {

 struct KVCache {
-  // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim * 2
+  size_t seq_len = 0;  // = kSeqLen + prefill_tbatch_size
+
+  // seq_len * kGemmaLayers * kKVHeads * kQKVDim * 2
   hwy::AlignedFreeUniquePtr<float[]> kv_cache;

   // (kConv1dWidth - 1) * kModelDim * kGriffinLayers
@@ -31,7 +35,7 @@ struct KVCache {
   // kModelDim * kGriffinLayers
   hwy::AlignedFreeUniquePtr<float[]> rglru_cache;

-  static KVCache Create(Model type);
+  static KVCache Create(Model type, size_t prefill_tbatch_size);
 };

 }  // namespace gcpp
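To make the sizing concrete, a back-of-envelope sketch with illustrative dimensions (not any particular Gemma config; real values come from TConfig):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  // Illustrative only.
  const size_t kSeqLen = 8192, kGemmaLayers = 28, kKVHeads = 16, kQKVDim = 256;
  const size_t prefill_tbatch_size = 64;

  // Per the comments above: padded by one prefill token batch so a full batch
  // can be written even near the end of the ring buffer.
  const size_t seq_len = kSeqLen + prefill_tbatch_size;
  const size_t floats = seq_len * kGemmaLayers * kKVHeads * kQKVDim * 2;  // K+V
  std::printf("KV cache: %zu floats = %.2f GiB\n", floats,
              floats * sizeof(float) / (1024.0 * 1024.0 * 1024.0));
  return 0;
}
```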
@@ -145,14 +145,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,

   TimingInfo timing_info;
   RuntimeConfig runtime_config = {
-      .max_tokens = args.max_tokens,
-      .max_generated_tokens = args.max_generated_tokens,
-      .temperature = args.temperature,
       .verbosity = verbosity,
       .gen = &gen,
       .stream_token = stream_token,
       .accept_token = accept_token,
   };
+  args.CopyTo(runtime_config);
   model.Generate(runtime_config, prompt, abs_pos, kv_cache, timing_info);
   if (verbosity >= 2) {
     std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
@@ -181,7 +179,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   }

   Gemma model = CreateGemma(loader, pool);
-  KVCache kv_cache = KVCache::Create(model.Info().model);
+  KVCache kv_cache =
+      KVCache::Create(model.Info().model, inference.prefill_tbatch_size);

   if (app.verbosity >= 1) {
     std::string instructions =
util/app.h (18 changed lines)
@@ -248,6 +248,9 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   size_t max_tokens;
   size_t max_generated_tokens;

+  size_t prefill_tbatch_size;
+  size_t decode_qbatch_size;
+
   float temperature;
   bool deterministic;
   bool multiturn;
@@ -272,6 +275,11 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
     visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
             "Maximum number of tokens to generate.");

+    visitor(prefill_tbatch_size, "prefill_tbatch", size_t{64},
+            "Prefill: max tokens per batch.");
+    visitor(decode_qbatch_size, "decode_qbatch", size_t{16},
+            "Decode: max queries per batch.");
+
     visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
     visitor(deterministic, "deterministic", false,
             "Make top-k sampling deterministic", 2);
@@ -281,6 +289,16 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
             " Default : 0 (conversation "
             "resets every turn)");
   }

+  void CopyTo(RuntimeConfig& runtime_config) const {
+    runtime_config.max_tokens = max_tokens;
+    runtime_config.max_generated_tokens = max_generated_tokens;
+
+    runtime_config.prefill_tbatch_size = prefill_tbatch_size;
+    runtime_config.decode_qbatch_size = decode_qbatch_size;
+
+    runtime_config.temperature = temperature;
+  }
 };

 }  // namespace gcpp
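End-to-end wiring of the new flags, mirroring the run.cc change earlier in this diff: --prefill_tbatch and --decode_qbatch are parsed into InferenceArgs, CopyTo forwards them into RuntimeConfig, and the prefill size also feeds KVCache::Create. A sketch (error handling and model loading elided):

```cpp
#include "gemma/gemma.h"
#include "util/app.h"

int main(int argc, char** argv) {
  // e.g. ./gemma ... --prefill_tbatch 128 --decode_qbatch 8
  gcpp::InferenceArgs inference(argc, argv);

  gcpp::RuntimeConfig runtime_config = {};
  inference.CopyTo(runtime_config);  // batch sizes, token limits, temperature

  // The same prefill batch size determines the KV cache ring-buffer padding:
  // gcpp::KVCache::Create(model_type, inference.prefill_tbatch_size);
  return 0;
}
```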