mirror of https://github.com/google/gemma.cpp.git
Major revamp #2 of Prefill: fix token order, parallel for multi-query

- Allocate only the required KV caches and activation batch size
- Add flags for batch sizes
- Const-correct interface: Span of const int.
- Also clean up the KVCache arg to a span.
- Move kPrefillBatchSize into RuntimeConfig and remove related global constants.

PiperOrigin-RevId: 655893197

This commit is contained in:
parent c1f243c351
commit aaf51898b6
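Caller-visible effect, as a minimal sketch assembled from the hunks below (not a complete program): `KVCache::Create` now takes the prefill token-batch size, which comes from the new `InferenceArgs` flags.

    gcpp::LoaderArgs loader(argc, argv);
    gcpp::InferenceArgs inference(argc, argv);  // provides prefill_tbatch_size
    hwy::ThreadPool pool(gcpp::AppArgs::GetSupportedThreadCount());
    gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
    // KV cache size now depends on the prefill token-batch size flag.
    gcpp::KVCache kv_cache =
        gcpp::KVCache::Create(loader.Info().model, inference.prefill_tbatch_size);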
@@ -224,6 +224,7 @@ cc_library(
         ":common",
         ":cross_entropy",
         ":gemma_lib",
+        ":kv_cache",
         # Placeholder for internal dep, do not remove.,
         "@benchmark//:benchmark",
         "//compression:compress",
@@ -52,7 +52,7 @@ TEST(OptimizeTest, GradientDescent) {
       CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight);
   ByteStorageT backward =
       CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight);
-  KVCache kv_cache = KVCache::Create(info.model);
+  KVCache kv_cache = KVCache::Create(info.model, /*prefill_tbatch_size=*/16);

   Gemma gemma(GemmaTokenizer(), info, pool);
@@ -128,7 +128,8 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
     size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
     std::vector<int> prompt_slice(prompt.begin() + pos,
                                   prompt.begin() + pos + num_tokens);
-    KVCache kv_cache = KVCache::Create(env.Info().model);
+    KVCache kv_cache = KVCache::Create(
+        env.Info().model, env.MutableInferenceArgs().prefill_tbatch_size);
     float entropy = ComputeCrossEntropy(
         *env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
     total_entropy += entropy;
@@ -34,9 +34,9 @@
 #include "evals/cross_entropy.h"
 #include "gemma/common.h"  // StringFromType
 #include "gemma/gemma.h"
+#include "gemma/kv_cache.h"
 #include "util/app.h"
 #include "util/args.h"
-#include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"
@@ -76,10 +76,10 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
     fprintf(stderr, "Loading model...\n");
   }
   model_ = AllocateGemma(loader_, pool_);
-  kv_caches_.reserve(kBatchedQueryBatchSize);
-  for (int i = 0; i < kBatchedQueryBatchSize; ++i) {
-    kv_caches_.push_back(new KVCache(KVCache::Create(model_->Info().model)));
-  }
+  // Only allocate one for starters because GenerateBatch might not be called.
+  kv_caches_.resize(1);
+  kv_caches_[0] =
+      KVCache::Create(model_->Info().model, inference.prefill_tbatch_size);

   InitGenerator(inference_args_, gen_);
@@ -132,7 +132,7 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
   }
   gcpp::TimingInfo timing_info;
   runtime_config_.batch_stream_token = batch_stream_token;
-  model_->Generate(runtime_config_, tokens, /*start_pos=*/0, *kv_caches_[0],
+  model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
                    timing_info);
   if (app_.verbosity >= 1) {
     LogSpeedStats(time_start, total_tokens);
@@ -141,8 +141,10 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
 }

 std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
-    const hwy::Span<const hwy::Span<int>>& prompts) {
-  std::vector<std::pair<std::string, size_t>> res(prompts.size());
+    const MultiplePromptsTokens& prompts) {
+  const size_t num_queries = prompts.size();
+  HWY_ASSERT(num_queries != 0);
+  std::vector<std::pair<std::string, size_t>> res(num_queries);
   std::fill(res.begin(), res.end(), std::make_pair("", 0));
   size_t total_tokens = 0;
@@ -162,14 +164,29 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
     return true;
   };
   if (app_.verbosity >= 2) {
-    std::cout << inference_args_.max_tokens << " "
-              << inference_args_.max_generated_tokens << " "
-              << inference_args_.temperature;
+    fprintf(stderr,
+            "Max tok: %zu max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
+            inference_args_.max_tokens, inference_args_.max_generated_tokens,
+            inference_args_.temperature, inference_args_.prefill_tbatch_size,
+            inference_args_.decode_qbatch_size);
   }

+  // Ensure we have one KVCache per query.
+  if (kv_caches_.size() < num_queries) {
+    kv_caches_.resize(num_queries);
+  }
+  for (size_t i = 1; i < num_queries; ++i) {
+    if (kv_caches_[i].seq_len == 0) {
+      kv_caches_[i] = KVCache::Create(model_->Info().model,
+                                      inference_args_.prefill_tbatch_size);
+    }
+  }
+
   gcpp::TimingInfo timing_info;
   runtime_config_.batch_stream_token = batch_stream_token;
-  model_->GenerateBatch(runtime_config_, prompts, /*start_pos=*/0, kv_caches_,
-                        timing_info);
+  inference_args_.CopyTo(runtime_config_);
+  model_->GenerateBatch(runtime_config_, prompts, /*start_pos=*/0,
+                        KVCaches(&kv_caches_[0], num_queries), timing_info);
   if (app_.verbosity >= 1) {
     LogSpeedStats(time_start, total_tokens);
   }
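After this change, callers pass non-owning spans. A sketch of the calling convention, mirroring `BatchQueryModel` in the next hunk (`Tokenize` is a hypothetical helper; `PromptTokens`/`MultiplePromptsTokens` are the span aliases this commit introduces):

    // Sketch: build non-owning spans over tokenized prompts, then query.
    std::vector<std::vector<int>> tokenized = Tokenize(inputs);  // hypothetical
    std::vector<PromptTokens> spans;
    spans.reserve(tokenized.size());
    for (const auto& t : tokenized) {
      spans.push_back(PromptTokens(t.data(), t.size()));
    }
    const MultiplePromptsTokens prompts(spans.data(), spans.size());
    env.BatchQueryModel2(prompts);  // creates any missing KV caches on demand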
@@ -191,13 +208,12 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
     prompts.push_back(WrapAndTokenize(model_->Tokenizer(), model_->Info(),
                                       /*pos=*/0, mutable_prompt));
   }
-  std::vector<hwy::Span<int>> prompt_vector;
+  std::vector<PromptTokens> prompt_vector;
   prompt_vector.reserve(prompts.size());
   for (auto& prompt : prompts) {
-    prompt_vector.push_back(hwy::Span<int>(prompt.data(), prompt.size()));
+    prompt_vector.push_back(PromptTokens(prompt.data(), prompt.size()));
   }
-  hwy::Span<const hwy::Span<int>> prompt_span = hwy::Span<const hwy::Span<int>>(
-      prompt_vector.data(), prompt_vector.size());
+  MultiplePromptsTokens prompt_span(prompt_vector.data(), prompt_vector.size());
   return BatchQueryModel2(prompt_span);
 }
@@ -226,8 +242,8 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   if (app.verbosity >= 2) {
     time_t now = time(nullptr);
     char* dt = ctime(&now);  // NOLINT
+    // TODO: replace hardware_concurrency with detected topology.
     std::cout << "Date & Time : " << dt
-              << "Prefill Token Batch Size : " << kPrefillBatchSize << "\n"
               << "Hardware concurrency : "
               << std::thread::hardware_concurrency() << "\n"
               << "Instruction set : "
@@ -69,7 +69,7 @@ class GemmaEnv {
   // the number of tokens that were generated.
   std::pair<std::string, size_t> QueryModel(const std::vector<int>& tokens);
   std::vector<std::pair<std::string, size_t>> BatchQueryModel2(
-      const hwy::Span<const hwy::Span<int>>& prompts);
+      const MultiplePromptsTokens& prompts);
   // Adds turn structure to input, tokenizes and calls the above overload.
   std::pair<std::string, size_t> QueryModel(std::string& input);
   std::vector<std::pair<std::string, size_t>> BatchQueryModel(
@@ -88,7 +88,7 @@ class GemmaEnv {
   const ModelInfo& Info() const { return loader_.Info(); }
   InferenceArgs& MutableInferenceArgs() { return inference_args_; }
   std::mt19937& MutableGen() { return gen_; }
-  KVCache& MutableKVCache() { return *kv_caches_[0]; }
+  KVCache& MutableKVCache() { return kv_caches_[0]; }

  private:
  // Arguments to the model loader: file locations, etc.
@@ -103,8 +103,8 @@ class GemmaEnv {
  std::mt19937 gen_;
  // The model to run inference on.
  std::unique_ptr<Gemma> model_;
-  // The KV cache to use for inference.
-  std::vector<KVCache*> kv_caches_;
+  // KV caches, same number as query batch.
+  std::vector<KVCache> kv_caches_;
  RuntimeConfig runtime_config_;
 };
@@ -17,14 +17,12 @@

 #include <stdio.h>

-#include <memory>
 #include <string>
 #include <vector>

 #include "evals/benchmark_helper.h"
 #include "gemma/common.h"
-#include "gemma/tokenizer.h"
-#include "hwy/aligned_allocator.h"
+#include "hwy/base.h"
 #include "hwy/tests/hwy_gtest.h"

 // This test can be run manually with the downloaded gemma weights.
@@ -75,21 +73,17 @@ class GemmaTest : public ::testing::Test {
         replies.push_back(response);
       }
     } else {  // Not Gemma-2 27B. Do not use turn structure.
-      std::vector<std::unique_ptr<std::vector<int>>> prompts;
-      prompts.reserve(inputs.size());
-      for (auto input_string : inputs) {
-        std::string mutable_input_string = input_string;
-        prompts.push_back(std::make_unique<std::vector<int>>(
-            s_env->TokenizeAndPrependBOS(input_string)));
+      std::vector<std::vector<int>> prompts_vector;
+      prompts_vector.reserve(inputs.size());
+      for (const auto& input_string : inputs) {
+        prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
       }
-      std::vector<hwy::Span<int>> prompt_vector;
-      for (auto& prompt : prompts) {
-        prompt_vector.push_back(hwy::Span<int>(prompt->data(), prompt->size()));
+      std::vector<PromptTokens> prompt_spans;
+      for (const auto& prompt : prompts_vector) {
+        prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
       }
-      hwy::Span<const hwy::Span<int>> prompt_span =
-          hwy::Span<const hwy::Span<int>>(prompt_vector.data(),
-                                          prompt_vector.size());
-      for (auto [response, n] : s_env->BatchQueryModel2(prompt_span)) {
+      MultiplePromptsTokens prompts(prompt_spans.data(), prompt_spans.size());
+      for (auto [response, n] : s_env->BatchQueryModel2(prompts)) {
         replies.push_back(response);
       }
     }
@@ -121,18 +115,20 @@ class GemmaTest : public ::testing::Test {
   }
 };

-TEST_F(GemmaTest, Geography) {
+TEST_F(GemmaTest, GeographyBatched) {
+  s_env->MutableInferenceArgs().decode_qbatch_size = 3;
+  // 6 are enough to test batching and the loop.
   static const char* kQA[][2] = {
-      {"What is the capital of Hungary?", "Budapest"},
       {"What is the capital of Australia?", "Canberra"},
+      {"What is the capital of Denmark?", "Copenhagen"},
+      {"Ljubljana is the capital of which country?", "Slovenia"},
+      {"Is Chicago a country?", "not"},
       {"How many states does the US have?", "50"},
+      {"What is the Pacific?", "ocean"},
   };
   static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
-  TestQuestions(kQA, kNum, /*batch=*/false);
-  static const char* kQA_single_question[][2] = {
-      {"What is the capital of Australia?", "Canberra"},
-  };
-  TestQuestions(kQA_single_question, 1, /*batch=*/true);
+  TestQuestions(kQA, HWY_MIN(kNum, 3), /*batch=*/false);
+  TestQuestions(kQA, 1, /*batch=*/true);
   TestQuestions(kQA, kNum, /*batch=*/true);
 }
@@ -31,6 +31,7 @@

 int main(int argc, char** argv) {
   gcpp::LoaderArgs loader(argc, argv);
+  gcpp::InferenceArgs inference(argc, argv);
   if (gcpp::HasHelp(argc, argv)) {
     loader.Help();
     return 0;
@@ -42,7 +43,8 @@ int main(int argc, char** argv) {
   // Instantiate model and KV Cache
   hwy::ThreadPool pool(gcpp::AppArgs::GetSupportedThreadCount());
   gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
-  gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.Info().model);
+  gcpp::KVCache kv_cache =
+      gcpp::KVCache::Create(loader.Info().model, inference.prefill_tbatch_size);
   size_t pos = 0;  // KV Cache position

   // Initialize random number generator
@@ -36,11 +36,6 @@ ByteStorageT AllocateSizeof() {
   return hwy::AllocateAligned<uint8_t>(sizeof(T));
 }

-// Relatively small so that we can also parallelize non-Matmul work. There is
-// one outer thread per batch, each with --num_threads / batches inner threads.
-constexpr size_t kPrefillBatchSize = 64;
-constexpr size_t kBatchedQueryBatchSize = 16;
-
 // Model variants: see configs.h for details. When adding a new one, also
 // update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc.
 enum class Model {
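The removed constants move to runtime configuration, per the commit message. The fields this diff relies on elsewhere are, in assumed shape (the actual declarations are outside this excerpt):

    // Assumed shape, inferred from uses elsewhere in this diff:
    struct RuntimeConfig {
      size_t prefill_tbatch_size;  // replaces kPrefillBatchSize (was 64)
      size_t decode_qbatch_size;   // replaces kBatchedQueryBatchSize (was 16)
      // ... StreamToken, batch_stream_token, and other existing members ...
    };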
@@ -73,10 +73,10 @@ template <class TConfig>
 HWY_NOINLINE void GriffinRecurrent(
     size_t batch_start, size_t num_tokens, size_t num_queries, size_t layer,
     Activations& activations, const CompressedLayer<TConfig>* layer_weights,
-    const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool) {
+    const KVCaches& kv_caches, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Griffin");
   HWY_ASSERT(num_queries == 1);  // TODO: add batch query support for Griffin.
-  KVCache& kv_cache = *kv_caches[0];
+  KVCache& kv_cache = kv_caches[0];
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
   static constexpr size_t kModelDim = TConfig::kModelDim;
@@ -208,7 +208,7 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
                                  size_t num_queries, size_t layer,
                                  Activations& activations,
                                  const CompressedLayer<TConfig>* layer_weights,
-                                 const std::vector<KVCache*>& kv_caches,
+                                 const KVCaches& kv_caches,
                                  hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Attention");
   HWY_DASSERT(interleaved_start % num_queries == 0);
@@ -221,6 +221,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
   constexpr size_t kKVHeads = TConfig::kKVHeads;
   constexpr size_t kSeqLen = TConfig::kSeqLen;
   GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale<TConfig>();

+  HWY_ASSERT(num_queries <= kv_caches.size());
+  const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
+
   // Multi-Head Attention a.k.a. "use_qkv_einsum".
   constexpr bool kIsMHA = Activations::IsMHA<TConfig>();
   static_assert(!kIsMHA || TConfig::kInterleaveQKV);  // MHA => interleaved
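In the hunks that follow, `pos % (kSeqLen + kPrefillBatchSize)` becomes `div_seq_len.Remainder(pos)`: the KV cache is treated as a ring buffer of runtime length `seq_len`, and `hwy::Divisor` precomputes the reciprocal so the remainder stays cheap inside per-token loops. A minimal illustration, assuming only that `hwy::Divisor` (from hwy/base.h) behaves as used in the diff:

    const hwy::Divisor div_seq_len(static_cast<uint32_t>(seq_len));
    // Same value as pos % seq_len, but without a hardware divide per call.
    const size_t cache_pos = div_seq_len.Remainder(static_cast<uint32_t>(pos));
    HWY_DASSERT(cache_pos == pos % seq_len);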
@@ -245,9 +249,9 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
         const float* x = activations.pre_att_rms_out.Batch(interleaved_idx);
         const size_t query_idx = interleaved_idx % num_queries;
         const size_t batch_idx = interleaved_idx / num_queries;
-        KVCache& kv_cache = *kv_caches[query_idx];
+        KVCache& kv_cache = kv_caches[query_idx];
         const size_t pos = batch_start + batch_idx;
-        const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
+        const size_t cache_pos = div_seq_len.Remainder(pos);
         const size_t kv_offset =
             cache_pos * kCachePosSize + layer * kCacheLayerSize;
         float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
@@ -268,10 +272,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
         const size_t query_idx = interleaved_idx % num_queries;
         const size_t batch_idx = interleaved_idx / num_queries;
         const size_t pos = batch_start + batch_idx;
-        const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
+        const size_t cache_pos = div_seq_len.Remainder(pos);
         const size_t kv_offset = cache_pos * kCachePosSize +
                                  layer * kCacheLayerSize + head * kQKVDim * 2;
-        KVCache& kv_cache = *kv_caches[query_idx];
+        KVCache& kv_cache = kv_caches[query_idx];
         float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
         if constexpr (kIsMHA) {
           // For MHA, copy KV into the KV cache from scratch space (see above).
@@ -297,7 +301,7 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
         const size_t query_idx = interleaved_idx % num_queries;
         const size_t batch_idx = interleaved_idx / num_queries;
         const size_t head_offset = (head / kHeadGroups) * kQKVDim * 2;
-        KVCache& kv_cache = *kv_caches[query_idx];
+        KVCache& kv_cache = kv_caches[query_idx];
         float* HWY_RESTRICT q =
             activations.q.Batch(interleaved_idx) + head * kQStride;
@@ -314,10 +318,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
         const size_t start_pos =
             pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
         for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
-          const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+          const size_t cache_pos = div_seq_len.Remainder(pos2);
           const size_t kv_offset =
               cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
-          const float* HWY_RESTRICT k = kv_cache.kv_cache.get() + kv_offset;
+          const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
           const float score = Dot(q, k, kQKVDim);
           head_att[pos2 % kSeqLen] = score;
         }
@@ -337,7 +341,7 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
             activations.att_out.Batch(interleaved_idx) + head * kQKVDim;
         hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
         for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
-          const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+          const size_t cache_pos = div_seq_len.Remainder(pos2);
           const size_t kv_offset =
               cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
           float* HWY_RESTRICT v =
@@ -383,8 +387,7 @@ HWY_NOINLINE void Attention(LayerAttentionType type, size_t interleaved_start,
                             size_t num_tokens, size_t num_queries, size_t layer,
                             Activations& activations,
                             const CompressedLayer<TConfig>* layer_weights,
-                            const std::vector<KVCache*>& kv_caches,
-                            hwy::ThreadPool& pool) {
+                            const KVCaches& kv_caches, hwy::ThreadPool& pool) {
   if (type == LayerAttentionType::kGemma) {
     GemmaAttention<TConfig>(interleaved_start, num_tokens, num_queries, layer,
                             activations, layer_weights, kv_caches, pool);
@@ -458,12 +461,13 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
                 output_bias, pool);
 }

-// TODO: pass Activations.x instead of Activations.
-// `pos` is for the entire batch and does not include `batch_idx`.
+// `batch_idx` indicates which row of `x` to write to.
+// `pos` is the *token*'s position, not the start of the batch, because this is
+// called for batches of tokens in prefill, but batches of queries in decode.
 template <class TConfig>
 HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
                              const CompressedWeights<TConfig>& weights,
-                             Activations& activations) {
+                             RowVectorBatch<float>& x) {
   constexpr size_t kModelDim = TConfig::kModelDim;
   GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
       EmbeddingScaling<TConfig>();
@@ -472,11 +476,10 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
   HWY_DASSERT(token < TConfig::kVocabSize);

   Decompress(weights.embedder_input_embedding, token * kModelDim,
-             activations.x.Batch(batch_idx), kModelDim);
-  MulByConst(kEmbScaling, activations.x.Batch(batch_idx), kModelDim);
+             x.Batch(batch_idx), kModelDim);
+  MulByConst(kEmbScaling, x.Batch(batch_idx), kModelDim);
   if constexpr (TConfig::kAbsolutePE) {
-    AddAbsolutePositionalEmbeddings(activations.x.Batch(batch_idx), kModelDim,
-                                    pos + batch_idx);
+    AddAbsolutePositionalEmbeddings(x.Batch(batch_idx), kModelDim, pos);
   };
 }
@@ -501,7 +504,7 @@ template <class TConfig>
 HWY_NOINLINE void TransformerLayer(
     size_t num_tokens, size_t num_queries, size_t pos, size_t layer,
     const CompressedLayer<TConfig>* layer_weights, Activations& activations,
-    const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool) {
+    const KVCaches& kv_caches, hwy::ThreadPool& pool) {
   constexpr size_t kModelDim = TConfig::kModelDim;
   const size_t num_interleaved = num_tokens * num_queries;
   auto type = TConfig::kLayerConfig[layer];
@@ -536,116 +539,220 @@ HWY_NOINLINE void TransformerLayer(
                            /*is_attention=*/false);
 }

-// For prefill, we have two-level parallelism:
-// - Outer: input tokens are split into batches, each of which is one task
-//   processed by a worker in `outer_pool_`, which includes the main thread
-//   because it is the one that calls `Prefill`.
+// Batches are important for amortizing loading weights over multiple tokens.
+// This is possible in prefill because we know all tokens beforehand, whereas
+// decode depends on the previous output token. However, each prefill batch of a
+// query requires that preceding batches already wrote to the KV cache, hence we
+// sequentially loop over token batches. We can reduce the number of iterations
+// by increasing the batch size, but this also increases arithmetic intensity,
+// and so we are eventually compute-limited. The tensor parallelism (number of
+// threads collaborating on MatMul) is also limited by the CPU topology:
+// fork/join barriers are slow(er) when some threads reside in a different NUMA
+// node. To allow more threads to help, we also support parallelizing over
+// queries in case GenerateBatch was called.
+//
+// Thus we have two-level parallelism:
+// - Outer: handles one 'qbatch' of entire queries. The set of outer workers
+//   includes the main thread because it is the one that calls `Prefill`, and is
+//   determined by the number of 'clusters' (shared L3 caches or sockets).
 // - Inner: each `outer` worker passes `inner_pools_[outer]` to
-//   `TransformerLayer` for tensor-level parallelism.
+//   `TransformerLayer` for tensor-level parallelism, and processes
+//   `tbatch_size` tokens from a single query at a time.
 //
-// This class holds the thread pools and activations, recreated for each query.
-//
-// It is safe to parallelize batches because we write to KVCache at
-// `pos % kSeqLen`, which is far greater than the number of workers.
-// Note however that this currently leads to nondeterministic results because
-// the RNG is invoked in different order.
+// This class holds the thread pools and one activation per outer worker. It is
+// NOT reused across calls to GenerateSingle/GenerateBatch so that we can adapt
+// to their num_queries.
 class PrefillState {
- public:
-  explicit PrefillState(hwy::ThreadPool& main_pool) : main_pool_(&main_pool) {}
-
-  ~PrefillState() { DeleteInnerPools(); }
-
-  // Called before each query. Recreates thread pools, which has the advantage
-  // of tailoring the parallelism to the prompt length.
-  template <class TConfig>
-  void Init(size_t prefill_size) {
-    // Would be zero for single-token prompts (prefill_size == num_tokens - 1).
-    num_batches_ =
-        HWY_MAX(size_t{1}, hwy::DivCeil(prefill_size, kPrefillBatchSize));
-    // More than num_batches_ would waste workers on idling in the outer Run;
-    // more than NumWorkers() would exceed the global --num_threads.
-    const size_t outer_workers =
-        HWY_MIN(num_batches_, main_pool_->NumWorkers());
-    HWY_ASSERT(outer_workers != 0);  // Otherwise activations_ is empty.
-
-    // One activation per outer worker. Allocating in parallel saves 30 ms.
-    activations_.resize(outer_workers);
-    main_pool_->Run(0, outer_workers, [this](uint64_t task, size_t /*thread*/) {
-      activations_[task].Allocate<TConfig>(kPrefillBatchSize);
-    });
-
-    DeleteInnerPools();
-
-    // If we'd create just one inner pool with all the workers, skip the cost of
-    // thread creation and pinning (about 60 ms) by reusing the main pool.
-    if (outer_workers <= 1) {
-      // Still allocate a dummy pool to simplify Prefill().
-      outer_pool_ = std::make_unique<hwy::ThreadPool>(1);
-      inner_pools_.push_back(main_pool_);
-      return;
-    }
+  // TODO: move helper functions, also those in app.h, to a threading header
+  using LPS = hwy::LogicalProcessorSet;
+  LPS Intersection(const LPS& big, const LPS& small) {
+    LPS both;
+    // Reduce expected work by iterating over the smaller set.
+    small.Foreach([big, &both](size_t idx) {
+      if (big.Get(idx)) both.Set(idx);
+    });
+    return both;
+  }
+
+  std::vector<size_t> CoresInLPS(const LPS& cluster) {
+    std::vector<size_t> cores;
+    cores.reserve(cluster.Count());
+    cluster.Foreach([&cores](size_t idx) { cores.push_back(idx); });
+    return cores;
+  }
+
+  // For each cluster (shared L3 cache), a bitset of cores.
+  using CoresPerCluster = std::vector<LPS>;
+
+  // Returns empty if detection failed.
+  CoresPerCluster DetectClusters() {
+    CoresPerCluster clusters;
+    // Which processors are not disabled via OS, taskset, or numactl.
+    LPS enabled;
+    // If we don't know, better to use just a single inner pool rather than risk
+    // oversubscribing to enabled cores.
+    if (!GetThreadAffinity(enabled)) return clusters;
+
+    hwy::Topology topology;
+    if (topology.packages.empty()) return clusters;
+
+    // For each cluster = outer, the cores that will be used for an inner pool.
+    CoresPerCluster inner_lps;
+    for (const hwy::Topology::Package& package : topology.packages) {
+      for (const hwy::Topology::Cluster& cluster : package.clusters) {
+        // Only use enabled cores, and only add if not empty.
+        const LPS lps = Intersection(enabled, cluster.lps);
+        if (lps.Any()) clusters.push_back(lps);
+      }
+    }
+
+    // Sort by descending number of enabled cores, so that we preferentially
+    // use the largest clusters.
+    std::sort(clusters.begin(), clusters.end(),
+              [](const LPS& a, const LPS& b) { return a.Count() > b.Count(); });
+
+    return clusters;
+  }
+
+  // Returns false if the main pool should be reused instead.
+  bool AssignInnerPoolsToClusters(const size_t num_queries) {
+    HWY_ASSERT(num_queries != 0);
+
+    CoresPerCluster inner_lps = DetectClusters();
+    // If we have more outer workers than queries, discard the excess.
+    if (inner_lps.size() > num_queries) inner_lps.resize(num_queries);
+    // If we're not going to create multiple pools, avoid the overhead of
+    // re-pinning (60 ms) and reuse the main pool.
+    if (inner_lps.size() <= 1) return false;

     // Before creating new threads, stop the old ones from spinning. Caller is
     // responsible for undoing this by calling `ResumeMainSpinning`.
     main_pool_->SetWaitMode(hwy::PoolWaitMode::kBlock);
-    outer_pool_ = std::make_unique<hwy::ThreadPool>(outer_workers);
+
+    outer_pool_ = std::make_unique<hwy::ThreadPool>(inner_lps.size());
     outer_pool_->SetWaitMode(hwy::PoolWaitMode::kSpin);

-    // Assign up to `max_workers` to inner pools. Each inner pool creates
-    // `workers_per_outer - 1` threads in addition to its 'main' thread, which
-    // is the one calling `inner_pools[outer]->Run`, i.e., `outer`. In total,
-    // `outer_workers * (max_workers / outer_workers)` workers are used.
-    const size_t workers_per_outer = main_pool_->NumWorkers() / outer_workers;
-    for (size_t outer = 0; outer < outer_workers; ++outer) {
-      inner_pools_.push_back(new hwy::ThreadPool(workers_per_outer));
+    HWY_ASSERT(inner_pools_.empty());
+    for (const LPS& inner : inner_lps) {
+      inner_pools_.push_back(new hwy::ThreadPool(inner.Count()));
       inner_pools_.back()->SetWaitMode(hwy::PoolWaitMode::kSpin);
     }

-    PinThreads(outer_workers, workers_per_outer);
-  }
-
-  // `tokens` are from interleaved queries. (See InterleaveQueries() below.)
+    // For each inner pool, pin their threads AND the associated outer thread
+    // to the enabled cores in the cluster.
+    outer_pool_->Run(
+        0, inner_lps.size(),
+        [this, &inner_lps](uint64_t outer, size_t outer_thread) {
+          HWY_ASSERT(outer == outer_thread);  // each outer has one task
+          const std::vector<size_t> cores = CoresInLPS(inner_lps[outer]);

+          inner_pools_[outer]->Run(
+              0, cores.size(), [&cores](uint64_t task, size_t thread) {
+                HWY_ASSERT(task == thread);  // each inner has one task
+                hwy::PinThreadToLogicalProcessor(cores[task]);
+              });
+        });
+
+    return true;
+  }
+
+  void ReuseMainPoolAsInner() {
+    // Still allocate an empty pool to simplify Prefill().
+    outer_pool_ = std::make_unique<hwy::ThreadPool>(1);
+
+    HWY_ASSERT(inner_pools_.empty());
+    inner_pools_.push_back(main_pool_);
+  }
+
+ public:
+  // Creates pools. AllocateActivations must still be called separately; it has
+  // a template argument.
+  PrefillState(hwy::ThreadPool& main_pool, size_t num_queries)
+      : main_pool_(&main_pool) {
+    PROFILER_ZONE("Init.Prefill.Ctor");
+    if (!AssignInnerPoolsToClusters(num_queries)) {
+      ReuseMainPoolAsInner();
+    }
+  }
+
+  ~PrefillState() {
+    for (hwy::ThreadPool* p : inner_pools_) {
+      if (p != main_pool_) delete p;
+    }
+  }
+
   template <class TConfig>
-  HWY_NOINLINE void Prefill(hwy::Span<const int> tokens, size_t num_queries,
-                            size_t pos,
+  void AllocateActivations(size_t num_queries, size_t tbatch_size) {
+    PROFILER_ZONE("Init.Prefill.AllocateActivations");
+
+    const size_t outer_workers = outer_pool_->NumWorkers();
+    HWY_ASSERT(outer_workers != 0);  // Otherwise activations_ is empty.
+
+    HWY_ASSERT(activations_.empty());  // only call once.
+    activations_.resize(outer_workers);
+
+    if (outer_workers == 1) {
+      activations_[0].Allocate<TConfig>(tbatch_size);
+    } else {
+      // Allocating in parallel can save 30 ms.
+      main_pool_->Run(0, outer_workers,
+                      [this, tbatch_size](uint64_t task, size_t /*thread*/) {
+                        activations_[task].Allocate<TConfig>(tbatch_size);
+                      });
+    }
+  }
+
+  // `tbatch_size` is the number of tokens from one query to prefill at a time.
+  template <class TConfig>
+  HWY_NOINLINE void Prefill(const MultiplePromptsTokens& prompts,
+                            const size_t prefill_per_query, const size_t pos,
+                            const size_t query_idx_start,
                             const CompressedWeights<TConfig>& weights,
                             const RuntimeConfig& runtime_config,
-                            const std::vector<KVCache*>& kv_caches) {
+                            const KVCaches& kv_caches) {
     PROFILER_ZONE("Gen.Prefill");
-    HWY_ASSERT(activations_.size() == outer_pool_->NumWorkers());
-    HWY_ASSERT(inner_pools_.size() == outer_pool_->NumWorkers());
+    const size_t num_queries = prompts.size();
+    HWY_ASSERT(kv_caches.size() == num_queries);
+    const size_t max_tbatch_size = activations_[0].x.BatchSize();
+
+    // For each query (parallel): an outer worker processes all its tokens.
+    // `qi` is relative to the batch, not the global query index.
     outer_pool_->Run(
-        0, num_batches_, [&](const uint64_t batch_num, size_t thread) HWY_ATTR {
-          const size_t batch_start = batch_num * kPrefillBatchSize;
-          const size_t batch_size =
-              HWY_MIN(kPrefillBatchSize, tokens.size() - batch_start);
-          HWY_DASSERT(batch_start % num_queries == 0);
-          HWY_DASSERT(batch_size % num_queries == 0);
-          const size_t pos_per_query = pos + batch_start / num_queries;
-          const size_t num_tokens = batch_size / num_queries;
-
-          // Negligible time compared to TransformerLayer.
-          for (size_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
-            EmbedToken<TConfig>(tokens[batch_start + batch_idx], batch_idx,
-                                pos_per_query, weights, activations_[thread]);
-          }
-
-          for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
-            const auto* layer_weights = weights.GetLayer(layer);
-            TransformerLayer<TConfig>(
-                num_tokens, num_queries, pos_per_query, layer, layer_weights,
-                activations_[thread], kv_caches, *inner_pools_[thread]);
-          }
-
-          // NOTE: we unconditionally call StreamToken, even if EOS.
-          for (size_t i = 0; i < batch_size; ++i) {
-            const size_t query_idx = i % num_queries;
-            const size_t batch_idx = i / num_queries;
-            runtime_config.StreamToken(query_idx, pos_per_query + batch_idx,
-                                       tokens[i], 0.0f);
-          }
+        0, num_queries, [&](const uint64_t qi, size_t qthread) HWY_ATTR {
+          Activations& activations = activations_[qthread];
+          hwy::ThreadPool& inner_pool = *inner_pools_[qthread];
+
+          // Single query at a time, so pass a slice of the KV cache because
+          // GemmaAttention will only access the first.
+          const size_t kPrefillQueries = 1;
+          KVCaches prefill_kv_caches(&kv_caches[qi], kPrefillQueries);
+
+          // For each batch of tokens in the query:
+          for (size_t tbatch_start = 0; tbatch_start < prefill_per_query;
+               tbatch_start += max_tbatch_size) {
+            // Fill activations.x (much faster than TransformerLayer).
+            const size_t tbatch_size =
+                HWY_MIN(max_tbatch_size, prefill_per_query - tbatch_start);
+            for (size_t ti = 0; ti < tbatch_size; ++ti) {
+              const int token = prompts[qi][tbatch_start + ti];
+              EmbedToken<TConfig>(token, ti, pos + ti, weights, activations.x);
+            }
+
+            // Transformer with one batch of tokens from a single query.
+            for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
+              const auto* layer_weights = weights.GetLayer(layer);
+              TransformerLayer<TConfig>(
+                  tbatch_size, kPrefillQueries, pos + tbatch_start, layer,
+                  layer_weights, activations, prefill_kv_caches, inner_pool);
+            }
+
+            // NOTE: we unconditionally call StreamToken, even if EOS.
+            for (size_t ti = 0; ti < tbatch_size; ++ti) {
+              const int token = prompts[qi][tbatch_start + ti];
+              runtime_config.StreamToken(query_idx_start + qi,
+                                         pos + tbatch_start + ti, token, 0.0f);
+            }
+          }  // for tbatch_start
         });
   }
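Condensing the new control flow: `GenerateT` (later in this diff) constructs one `PrefillState` per call, sized to the query batch. A sketch of the call sequence, copied in shape from the GenerateT hunk below:

    PrefillState prefill(pool, num_queries);  // one inner pool per CPU cluster
    prefill.AllocateActivations<TConfig>(num_queries,
                                         runtime_config.prefill_tbatch_size);
    prefill.Prefill<TConfig>(prompts, prefill_per_query, pos, query_idx_start,
                             weights, runtime_config, kv_caches);
    prefill.ResumeMainSpinning();  // restore the main pool's spin-wait mode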
@@ -663,39 +770,15 @@ class PrefillState {
   }

  private:
-  // Pins each outer thread after their inner threads so they are likely to
-  // run on the same socket.
-  void PinThreads(size_t outer_workers, size_t workers_per_outer) {
-    outer_pool_->Run(
-        0, outer_workers,
-        [this, workers_per_outer](uint64_t outer, size_t outer_thread) {
-          HWY_ASSERT(outer == outer_thread);
-          // Pins inner *and* `outer` - the latter is the calling thread.
-          inner_pools_[outer]->Run(
-              0, workers_per_outer,
-              [outer, workers_per_outer](uint64_t task, size_t thread) {
-                HWY_ASSERT(task == thread);  // each worker has one task
-                const size_t lp = outer * workers_per_outer + task;
-                hwy::PinThreadToLogicalProcessor(lp);
-              });
-        });
-  }
-
-  void DeleteInnerPools() {
-    for (hwy::ThreadPool* p : inner_pools_) {
-      if (p != main_pool_) delete p;
-    }
-    inner_pools_.clear();
-  }
-
   hwy::ThreadPool* main_pool_;
   std::unique_ptr<hwy::ThreadPool> outer_pool_;  // always allocated
-  std::vector<Activations> activations_;  // size == outer_pool->NumWorkers()
-  // Either there is a single pointer equal to main_pool, or newly created pools
-  // that we own. The former case avoids thread creation overhead for prompts
-  // that fit in a single batch.
+  // Holds a single pointer equal to main_pool_, or new allocations; in either
+  // case, size() is equal to outer_pool_->NumWorkers(). The first case avoids
+  // allocation overhead for the common case of a single query.
   std::vector<hwy::ThreadPool*> inner_pools_;
-  size_t num_batches_ = 0;
+
+  // size() == outer_pool_->NumWorkers(); filled by AllocateActivations.
+  std::vector<Activations> activations_;
 };

 // `tokens` is length `num_tokens * num_queries`. In autoregressive decode,
@@ -705,8 +788,7 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
                               size_t num_queries, size_t pos,
                               const CompressedWeights<TConfig>& weights,
                               Activations& activations,
-                              const std::vector<KVCache*>& kv_caches,
-                              hwy::ThreadPool& pool,
+                              const KVCaches& kv_caches, hwy::ThreadPool& pool,
                               const LayersOutputFunc& layers_output) {
   const size_t num_interleaved = num_tokens * num_queries;
   if (layers_output) {
@@ -718,7 +800,7 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
   constexpr size_t kModelDim = TConfig::kModelDim;
   for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
     EmbedToken<TConfig>(tokens[token_idx], token_idx, pos, weights,
-                        activations);
+                        activations.x);
   }

   for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
@@ -781,10 +863,10 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,

 // Returns interleaved tokens: one from each query, followed by the second from
 // all queries, with EOS padding.
-static std::vector<int> InterleaveQueries(
-    const hwy::Span<const hwy::Span<int>>& queries,
-    const RuntimeConfig& runtime_config, size_t& min_prompt_size,
+static std::vector<int> InterleaveQueries(const MultiplePromptsTokens& queries,
+                                          const RuntimeConfig& runtime_config,
+                                          size_t& min_prompt_size,
                                           size_t& max_prompt_size) {
   const size_t num_queries = queries.size();
   min_prompt_size = hwy::LimitsMax<size_t>();
   max_prompt_size = 0;
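For concreteness, the interleaved layout for two queries, with EOS padding for the shorter one (comment-only sketch):

    // queries[0] = {A0, A1, A2}, queries[1] = {B0, B1}
    // InterleaveQueries -> {A0, B0, A1, B1, A2, EOS}
    // Token p of query q is at index p * num_queries + q, which matches the
    // query_idx = i % num_queries / batch_idx = i / num_queries decoding in
    // GemmaAttention above.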
@@ -829,28 +911,34 @@ class TokenStreamer {

  private:
  const RuntimeConfig& runtime_config_;
-  // BitSet4096 divides the arg by 64, so ensure it is at least 64.
-  hwy::BitSet4096<HWY_MAX(64, kBatchedQueryBatchSize)> is_eos_;
+  hwy::BitSet4096<> is_eos_;
 };

-// Generates one token per query in the batch.
+// Generates one token for each query in `prompts`, which is one qbatch whose
+// size is at most the `batch_size` passed to `activations.Allocate`.
 //
-// pos indexes the KV cache. In the first turn of a chat, pos = 0, and it
+// `pos` indexes the KV cache. In the first turn of a chat, pos = 0, and it
 // continues to increase by one for each prefilled/generated token per query.
-// query_idx_start is the first query index in the batch.
-template <class TConfig, size_t kQueryBatchSize>
+//
+// `query_idx_start` is the query_idx of the first query in the batch, so that
+// `StreamFunc` gets the global query index, not relative to the batch.
+//
+// `kv_caches` is for the batch, size must match `prompts`.
+template <class TConfig>
 void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
                const RuntimeConfig& runtime_config,
-               const hwy::Span<const hwy::Span<int>>& prompts, const size_t pos,
-               const size_t query_idx_start,
-               const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
-               TimingInfo& timing_info) {
+               const MultiplePromptsTokens& prompts, const size_t pos,
+               const size_t query_idx_start, const KVCaches& kv_caches,
+               hwy::ThreadPool& pool, TimingInfo& timing_info) {
   constexpr size_t kVocabSize = TConfig::kVocabSize;
   const CompressedWeights<TConfig>& weights =
       *reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());

   const size_t num_queries = prompts.size();
-  HWY_DASSERT(num_queries <= kQueryBatchSize);
+  HWY_ASSERT(num_queries <= 4096);  // TokenStreamer uses BitSet4096.
+  HWY_ASSERT(num_queries <= activations.x.BatchSize());
+  HWY_ASSERT(kv_caches.size() == num_queries);

   size_t min_prompt_size, max_prompt_size;
   const std::vector<int> prompt = InterleaveQueries(
       prompts, runtime_config, min_prompt_size, max_prompt_size);
@@ -877,28 +965,28 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
   // Prefill stops before min_prompt_size - 1 because the last prompt token is
   // the first input token for generation.
   const size_t prefill_per_query = min_prompt_size - 1;
-  const hwy::Span<const int> prefill_tokens(prompt.data(),
-                                            prefill_per_query * num_queries);
-  PrefillState prefill(pool);
-  prefill.Init<TConfig>(prefill_tokens.size());
-  const double prefill_start = hwy::platform::Now();
-  size_t interleaved_pos = pos * num_queries;
-  prefill.Prefill<TConfig>(prefill_tokens, num_queries, interleaved_pos,
-                           weights, runtime_config, kv_caches);
-  interleaved_pos += prefill_tokens.size();
-  timing_info.NotifyPrefill(prefill_tokens.size(), prefill_start);
-
-  prefill.ResumeMainSpinning();
+  double prefill_start;
+  {
+    PrefillState prefill(pool, num_queries);
+    prefill.AllocateActivations<TConfig>(num_queries,
+                                         runtime_config.prefill_tbatch_size);
+    prefill_start = hwy::platform::Now();
+    prefill.Prefill<TConfig>(prompts, prefill_per_query, pos, query_idx_start,
+                             weights, runtime_config, kv_caches);
+    timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start);
+    prefill.ResumeMainSpinning();
+  }
+
+  size_t interleaved_pos = (pos + prefill_per_query) * num_queries;

   // Storage for the last generated token from each query, passed to the next
   // Transformer() call.
   std::vector<int> gen_tokens(num_queries);

   // Stream the last prompt token from each query and fill gen_tokens.
-  hwy::CopyBytes(&prompt[prefill_tokens.size()], gen_tokens.data(),
-                 num_queries * sizeof(prompt[0]));
   TokenStreamer token_streamer(runtime_config);
   for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
+    gen_tokens[query_idx] = prompts[query_idx][prefill_per_query];
     (void)token_streamer(query_idx_start + query_idx, prefill_per_query,
                          gen_tokens[query_idx], 0.0f);
   }
@@ -940,42 +1028,49 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
   timing_info.NotifyGenerateDone(gen_start);
 }
 
-// TODO: prompt should also be span, not a vector.
 template <class TConfig>
-void GenerateSingleT(const ByteStorageT& weights_u8, Activations& activations,
+void GenerateSingleT(const ByteStorageT& weights_u8,
                      const RuntimeConfig& runtime_config,
-                     const std::vector<int>& prompt, size_t pos,
-                     KVCache& kv_cache, hwy::ThreadPool& pool,
-                     TimingInfo& timing_info) {
-  const hwy::Span<int> prompt_span(const_cast<int*>(prompt.data()),
-                                   prompt.size());
-  const hwy::Span<const hwy::Span<int>> prompts(&prompt_span, 1);
-  // TODO: also span of kv_cache, or batching inside KVCache?
-  std::vector<KVCache*> kv_caches = {&kv_cache};
-  const size_t query_idx_start = 0;
-  GenerateT<TConfig, /*kQueryBatchSize=*/1>(
-      weights_u8, activations, runtime_config, prompts, pos, query_idx_start,
-      kv_caches, pool, timing_info);
+                     const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
+                     hwy::ThreadPool& pool, TimingInfo& timing_info) {
+  const size_t num_queries = 1;
+  const size_t qbatch_start = 0;
+
+  Activations activations;
+  activations.Allocate<TConfig>(num_queries);
+
+  const MultiplePromptsTokens prompts(&prompt, num_queries);
+  const KVCaches kv_caches{&kv_cache, num_queries};
+
+  GenerateT<TConfig>(weights_u8, activations, runtime_config, prompts, pos,
+                     qbatch_start, kv_caches, pool, timing_info);
 }
 
 template <class TConfig>
-void GenerateBatchT(const ByteStorageT& weights_u8, Activations& activations,
+void GenerateBatchT(const ByteStorageT& weights_u8,
                     const RuntimeConfig& runtime_config,
-                    const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
-                    const std::vector<KVCache*>& kv_caches,
-                    hwy::ThreadPool& pool, TimingInfo& timing_info) {
-  // Disable query batching for Griffin models.
-  constexpr size_t kQueryBatchSize =
-      (TConfig::kGriffinLayers > 0) ? 1 : kBatchedQueryBatchSize;
-  for (size_t query_idx_start = 0; query_idx_start < prompts.size();
-       query_idx_start += kQueryBatchSize) {
-    const size_t num_queries =
-        std::min(prompts.size() - query_idx_start, kQueryBatchSize);
-    const hwy::Span<const hwy::Span<int>> query_batch(
-        prompts.data() + query_idx_start, num_queries);
-    GenerateT<TConfig, kQueryBatchSize>(weights_u8, activations, runtime_config,
-                                        query_batch, pos, query_idx_start,
-                                        kv_caches, pool, timing_info);
+                    const MultiplePromptsTokens& prompts, size_t pos,
+                    const KVCaches& kv_caches, hwy::ThreadPool& pool,
+                    TimingInfo& timing_info) {
+  HWY_ASSERT(prompts.size() == kv_caches.size());
+  // Griffin does not support query batching.
+  const size_t max_qbatch_size =
+      (TConfig::kGriffinLayers > 0) ? 1 : runtime_config.decode_qbatch_size;
+
+  Activations activations;
+  activations.Allocate<TConfig>(max_qbatch_size);
+
+  const size_t num_queries = prompts.size();
+  for (size_t qbatch_start = 0; qbatch_start < num_queries;
+       qbatch_start += max_qbatch_size) {
+    // Generate one batch of tokens from `qbatch_size` queries.
+    const size_t qbatch_size =
+        HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
+    const MultiplePromptsTokens qbatch_prompts(&prompts[qbatch_start],
+                                               qbatch_size);
+    const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
+    GenerateT<TConfig>(weights_u8, activations, runtime_config, qbatch_prompts,
+                       pos, qbatch_start, qbatch_kv, pool, timing_info);
   }
 }
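Illustration only, not part of the commit: the qbatch partitioning pattern GenerateBatchT now uses, reduced to a standalone sketch (std::min stands in for HWY_MIN; the sizes are made up).

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    int main() {
      const size_t num_queries = 10;     // total prompts
      const size_t max_qbatch_size = 4;  // runtime_config.decode_qbatch_size
      for (size_t qbatch_start = 0; qbatch_start < num_queries;
           qbatch_start += max_qbatch_size) {
        // The last batch may be smaller than max_qbatch_size.
        const size_t qbatch_size =
            std::min(num_queries - qbatch_start, max_qbatch_size);
        printf("qbatch [%zu, %zu)\n", qbatch_start, qbatch_start + qbatch_size);
      }
      return 0;
    }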
@@ -986,24 +1081,20 @@ void GenerateBatchT(const ByteStorageT& weights_u8, Activations& activations,
 // These are extern functions defined by instantiations/*.cc, which include this
 // 'header' after defining GEMMA_CONFIG, which is for function overloading.
 void GenerateSingle(  // NOLINT(misc-definitions-in-headers)
-    GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& activations,
-    const RuntimeConfig& runtime_config, const std::vector<int>& prompt,
-    size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-    TimingInfo& timing_info) {
+    GEMMA_CONFIG, const ByteStorageT& weights_u8,
+    const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos,
+    KVCache& kv_cache, hwy::ThreadPool& pool, TimingInfo& timing_info) {
   HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_CONFIG>)
-  (weights_u8, activations, runtime_config, prompt, pos, kv_cache, pool,
-   timing_info);
+  (weights_u8, runtime_config, prompt, pos, kv_cache, pool, timing_info);
 }
 
 void GenerateBatch(  // NOLINT(misc-definitions-in-headers)
-    GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& activations,
-    const RuntimeConfig& runtime_config,
-    const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
-    const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
+    GEMMA_CONFIG, const ByteStorageT& weights_u8,
+    const RuntimeConfig& runtime_config, const MultiplePromptsTokens& prompts,
+    size_t pos, const KVCaches& kv_caches, hwy::ThreadPool& pool,
     TimingInfo& timing_info) {
   HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_CONFIG>)
-  (weights_u8, activations, runtime_config, prompts, pos, kv_caches, pool,
-   timing_info);
+  (weights_u8, runtime_config, prompts, pos, kv_caches, pool, timing_info);
 }
 
 #endif  // HWY_ONCE
@@ -24,32 +24,19 @@
 #include <string.h>
 
 #include <utility>  // std::move
-#include <vector>
 
 #include "compression/io.h"  // Path
-#include "gemma/activations.h"
 #include "gemma/common.h"
 #include "gemma/weights.h"
-#include "hwy/aligned_allocator.h"  // Span
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"
 
 namespace gcpp {
 
-template <typename TConfig>
-struct AllocateActivations {
-  void operator()(Activations& decode) const {
-    // TODO: this is wasted if we only have single-batch queries. Instead
-    // re-allocate when the query batch size is actually > 1?
-    decode.Allocate<TConfig>(kBatchedQueryBatchSize);
-  }
-};
-
 Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
              const ModelInfo& info, hwy::ThreadPool& pool)
     : pool_(pool), tokenizer_(tokenizer_path), info_(info) {
   weights_u8_ = LoadCompressedWeights(weights, info.model, info.weight, pool);
-  CallForModelAndWeight<AllocateActivations>(info.model, info.weight, decode_);
 }
 
 Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
@@ -58,7 +45,6 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
   HWY_ASSERT(info.weight == Type::kF32);
   weights_u8_ =
       CallForModel<float, AllocateCompressedWeights>(info.model, pool);
-  CallForModelAndWeight<AllocateActivations>(info.model, info.weight, decode_);
 }
 
 Gemma::~Gemma() {
@@ -70,67 +56,64 @@ Gemma::~Gemma() {
 // we shard them across multiple translation units in instantiations/*.cc.
 // This declares the functions defined there. We use overloading because
 // explicit instantiations are still too slow to compile.
 #define GEMMA_DECLARE(CONFIGT, TWEIGHT)                                        \
-  extern void GenerateSingle(                                                  \
-      CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, Activations& decode,   \
-      const RuntimeConfig& runtime_config, const std::vector<int>& prompt,     \
-      size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool,                    \
+  extern void GenerateSingle(CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
+                             const RuntimeConfig& runtime_config,              \
+                             const PromptTokens& prompt, size_t pos,           \
+                             KVCache& kv_cache, hwy::ThreadPool& pool,         \
                              TimingInfo& timing_info);                         \
-  extern void GenerateBatch(                                                   \
-      CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, Activations& decode,   \
-      const RuntimeConfig& runtime_config,                                     \
-      const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,              \
-      const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,           \
-      TimingInfo& timing_info);
+  extern void GenerateBatch(CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8,  \
+                            const RuntimeConfig& runtime_config,               \
+                            const MultiplePromptsTokens& prompts, size_t pos,  \
+                            const KVCaches& kv_caches, hwy::ThreadPool& pool,  \
+                            TimingInfo& timing_info);
 GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE);
 
 // Adapters to select from the above overloads via CallForModelAndWeight.
-// TODO: gather all ByteStorageT into a type-erased model struct?
 template <class TConfig>
 struct GenerateSingleT {
-  void operator()(const ByteStorageT& weights_u8, Activations& decode,
+  void operator()(const ByteStorageT& weights_u8,
                   const RuntimeConfig& runtime_config,
-                  const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
+                  const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
                   hwy::ThreadPool& pool, TimingInfo& timing_info) const {
-    GenerateSingle(TConfig(), weights_u8, decode, runtime_config, prompt, pos,
-                   kv_cache, pool, timing_info);
+    GenerateSingle(TConfig(), weights_u8, runtime_config, prompt, pos, kv_cache,
+                   pool, timing_info);
   }
 };
 
 template <class TConfig>
 struct GenerateBatchT {
-  void operator()(const ByteStorageT& weights_u8, Activations& decode,
+  void operator()(const ByteStorageT& weights_u8,
                   const RuntimeConfig& runtime_config,
-                  const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
-                  const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
+                  const MultiplePromptsTokens& prompts, size_t pos,
+                  const KVCaches& kv_caches, hwy::ThreadPool& pool,
                   TimingInfo& timing_info) const {
-    GenerateBatch(TConfig(), weights_u8, decode, runtime_config, prompts, pos,
+    GenerateBatch(TConfig(), weights_u8, runtime_config, prompts, pos,
                   kv_caches, pool, timing_info);
   }
 };
 
 void Gemma::Generate(const RuntimeConfig& runtime_config,
-                     const std::vector<int>& prompt, size_t start_pos,
+                     const PromptTokens& prompt, size_t start_pos,
                      KVCache& kv_cache, TimingInfo& timing_info) {
   pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
 
-  CallForModelAndWeight<GenerateSingleT>(
-      info_.model, info_.weight, weights_u8_, decode_, runtime_config, prompt,
-      start_pos, kv_cache, pool_, timing_info);
+  CallForModelAndWeight<GenerateSingleT>(info_.model, info_.weight, weights_u8_,
+                                         runtime_config, prompt, start_pos,
+                                         kv_cache, pool_, timing_info);
 
   pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
 
 void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
-                          const hwy::Span<const hwy::Span<int>>& prompts,
-                          size_t start_pos,
-                          const std::vector<KVCache*>& kv_caches,
+                          const MultiplePromptsTokens& prompts,
+                          size_t start_pos, const KVCaches& kv_caches,
                           TimingInfo& timing_info) {
   pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
 
-  CallForModelAndWeight<GenerateBatchT>(
-      info_.model, info_.weight, weights_u8_, decode_, runtime_config, prompts,
-      start_pos, kv_caches, pool_, timing_info);
+  CallForModelAndWeight<GenerateBatchT>(info_.model, info_.weight, weights_u8_,
+                                        runtime_config, prompts, start_pos,
+                                        kv_caches, pool_, timing_info);
 
   pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
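Rough sketch, not gemma.cpp's actual implementation: the idea behind the CallForModelAndWeight-style dispatch used above is to map a runtime model enum onto a compile-time config type and invoke a templated functor. The enum values and dimensions below are invented for illustration.

    #include <cstdio>

    enum class Model { kGemma2B, kGemma7B };  // hypothetical subset

    struct ConfigGemma2B { static constexpr int kModelDim = 2048; };
    struct ConfigGemma7B { static constexpr int kModelDim = 3072; };

    // Invokes FuncT<TConfig>()(args...) with the config matching `model`.
    template <template <class> class FuncT, typename... Args>
    void CallForModel(Model model, Args&&... args) {
      switch (model) {
        case Model::kGemma2B: FuncT<ConfigGemma2B>()(args...); break;
        case Model::kGemma7B: FuncT<ConfigGemma7B>()(args...); break;
      }
    }

    template <class TConfig>
    struct PrintDim {
      void operator()() const { printf("kModelDim = %d\n", TConfig::kModelDim); }
    };

    int main() { CallForModel<PrintDim>(Model::kGemma7B); }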
@@ -30,7 +30,7 @@
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/timer.h"
 // IWYU pragma: end_exports
-#include "hwy/aligned_allocator.h"
+#include "hwy/aligned_allocator.h"  // Span
 #include "hwy/base.h"  // hwy::bfloat16_t
 
 namespace gcpp {
@@ -67,6 +67,13 @@ struct RuntimeConfig {
   size_t max_tokens;
   size_t max_generated_tokens;
 
+  // These defaults are overridden by InferenceArgs::CopyTo(*this):
+  // Max tokens per batch during prefill.
+  size_t prefill_tbatch_size = 32;
+  // Max queries per batch (one token from each) during decode.
+  size_t decode_qbatch_size = 16;
+
   float temperature;
   int verbosity;
   std::mt19937* gen;
@@ -105,6 +112,10 @@ struct TimingInfo {
   size_t tokens_generated;
 };
 
+using PromptTokens = hwy::Span<const int>;
+using MultiplePromptsTokens = hwy::Span<const PromptTokens>;
+using KVCaches = hwy::Span<KVCache>;
+
 class Gemma {
  public:
  Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
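Usage sketch for the new aliases, assuming hwy::Span's (pointer, size) constructor from hwy/aligned_allocator.h; the token ids are arbitrary examples, not real tokenizer output.

    #include <vector>
    #include "hwy/aligned_allocator.h"  // hwy::Span

    using PromptTokens = hwy::Span<const int>;
    using MultiplePromptsTokens = hwy::Span<const PromptTokens>;

    int main() {
      std::vector<int> tokens1 = {2, 651, 6575};  // example token ids
      std::vector<int> tokens2 = {2, 1841, 603};
      const PromptTokens per_query[] = {
          PromptTokens(tokens1.data(), tokens1.size()),
          PromptTokens(tokens2.data(), tokens2.size())};
      const MultiplePromptsTokens prompts(per_query, 2);
      return static_cast<int>(prompts.size());  // 2
    }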
@@ -118,25 +129,20 @@ class Gemma {
   const ModelInfo& Info() const { return info_; }
   const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
   const ByteStorageT& Weights() const { return weights_u8_; }
-  const Activations& Decode() const { return decode_; }
 
-  void Generate(const RuntimeConfig& runtime_config,
-                const std::vector<int>& prompt, size_t start_pos,
-                KVCache& kv_cache, TimingInfo& timing_info);
+  void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
+                size_t start_pos, KVCache& kv_cache, TimingInfo& timing_info);
 
   void GenerateBatch(const RuntimeConfig& runtime_config,
-                     const hwy::Span<const hwy::Span<int>>& prompts,
-                     size_t start_pos, const std::vector<KVCache*>& kv_caches,
-                     TimingInfo& timing_info);
+                     const MultiplePromptsTokens& prompts, size_t start_pos,
+                     const KVCaches& kv_caches, TimingInfo& timing_info);
 
  private:
   hwy::ThreadPool& pool_;
 
   GemmaTokenizer tokenizer_;
-  // Type-erased so that this can be defined in the header, without requiring
-  // forwarding functions.
+  // Type-erased so that this can be defined in the header.
   ByteStorageT weights_u8_;
-  Activations decode_;
   ModelInfo info_;
 };
@@ -23,13 +23,16 @@ namespace gcpp {
 namespace {
 template <class TConfig>
 struct CreateKVCache {
-  KVCache operator()() const {
+  KVCache operator()(size_t prefill_tbatch_size) const {
     KVCache kv_cache = {};
 
     const size_t size_cache_pos = CachePosSize<TConfig>()();
     if (size_cache_pos != 0) {
-      const size_t seq_len = (TConfig::kSeqLen + kPrefillBatchSize);
-      kv_cache.kv_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
+      // Allocate more so that prefill can always access one batch, even if
+      // near the end of the sequence.
+      kv_cache.seq_len = TConfig::kSeqLen + prefill_tbatch_size;
+      kv_cache.kv_cache =
+          hwy::AllocateAligned<float>(kv_cache.seq_len * size_cache_pos);
     }
 
     // TODO(patrickms): Add query batching support for Griffin.
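Worked example of the new allocation size, standalone and with hypothetical stand-ins for the TConfig constants; per the kv_cache.h comment below, size_cache_pos corresponds to kGemmaLayers * kKVHeads * kQKVDim * 2 floats per position.

    #include <cstddef>
    #include <cstdio>

    int main() {
      const size_t kSeqLen = 4096;            // model's max sequence length
      const size_t prefill_tbatch_size = 64;  // --prefill_tbatch
      const size_t size_cache_pos = 26 * 1 * 256 * 2;  // floats per position
      // One extra prefill batch beyond kSeqLen, as allocated above.
      const size_t seq_len = kSeqLen + prefill_tbatch_size;
      const size_t floats = seq_len * size_cache_pos;
      printf("%zu floats = %zu MiB\n", floats, floats * sizeof(float) >> 20);
      return 0;
    }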
@@ -58,10 +61,13 @@ struct CreateKVCache {
 };
 }  // namespace
 
-KVCache KVCache::Create(Model model_type) {
+// prefill_tbatch_size is the maximum number of tokens from one query to
+// prefill at a time.
+KVCache KVCache::Create(Model model_type, size_t prefill_tbatch_size) {
   // TWeight=float is a placeholder and unused because CreateKVCache does not
   // use TConfig::Weight.
-  return CallForModel</*TWeight=*/float, CreateKVCache>(model_type);
+  return CallForModel</*TWeight=*/float, CreateKVCache>(model_type,
+                                                        prefill_tbatch_size);
 }
 
 }  // namespace gcpp
@@ -16,13 +16,17 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
 
+#include <stddef.h>
+
 #include "gemma/common.h"  // Model
 #include "hwy/aligned_allocator.h"
 
 namespace gcpp {
 
 struct KVCache {
-  // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim * 2
+  size_t seq_len = 0;  // = kSeqLen + prefill_tbatch_size
+
+  // seq_len * kGemmaLayers * kKVHeads * kQKVDim * 2
   hwy::AlignedFreeUniquePtr<float[]> kv_cache;
 
   // (kConv1dWidth - 1) * kModelDim * kGriffinLayers
@@ -31,7 +35,7 @@ struct KVCache {
   // kModelDim * kGriffinLayers
   hwy::AlignedFreeUniquePtr<float[]> rglru_cache;
 
-  static KVCache Create(Model type);
+  static KVCache Create(Model type, size_t prefill_tbatch_size);
 };
 
 }  // namespace gcpp
@@ -145,14 +145,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
 
   TimingInfo timing_info;
   RuntimeConfig runtime_config = {
-      .max_tokens = args.max_tokens,
-      .max_generated_tokens = args.max_generated_tokens,
-      .temperature = args.temperature,
       .verbosity = verbosity,
       .gen = &gen,
       .stream_token = stream_token,
       .accept_token = accept_token,
   };
+  args.CopyTo(runtime_config);
   model.Generate(runtime_config, prompt, abs_pos, kv_cache, timing_info);
   if (verbosity >= 2) {
     std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
@@ -181,7 +179,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   }
 
   Gemma model = CreateGemma(loader, pool);
-  KVCache kv_cache = KVCache::Create(model.Info().model);
+  KVCache kv_cache =
+      KVCache::Create(model.Info().model, inference.prefill_tbatch_size);
 
   if (app.verbosity >= 1) {
     std::string instructions =
util/app.h | 18
@@ -248,6 +248,9 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   size_t max_tokens;
   size_t max_generated_tokens;
 
+  size_t prefill_tbatch_size;
+  size_t decode_qbatch_size;
+
   float temperature;
   bool deterministic;
   bool multiturn;
@@ -272,6 +275,11 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
     visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
             "Maximum number of tokens to generate.");
 
+    visitor(prefill_tbatch_size, "prefill_tbatch", size_t{64},
+            "Prefill: max tokens per batch.");
+    visitor(decode_qbatch_size, "decode_qbatch", size_t{16},
+            "Decode: max queries per batch.");
+
     visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
     visitor(deterministic, "deterministic", false,
             "Make top-k sampling deterministic", 2);
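Rough sketch, not the real ArgsBase machinery: the visitor idiom used above lets a single member function both register flag names, defaults, and help text, and initialize the fields, depending on the visitor passed in.

    #include <cstddef>

    struct Args {
      size_t prefill_tbatch_size;
      size_t decode_qbatch_size;

      template <class Visitor>
      void ForEach(Visitor& visitor) {
        visitor(prefill_tbatch_size, "prefill_tbatch", size_t{64},
                "Prefill: max tokens per batch.");
        visitor(decode_qbatch_size, "decode_qbatch", size_t{16},
                "Decode: max queries per batch.");
      }
    };

    struct SetDefaults {  // a real visitor might instead parse argv
      void operator()(size_t& field, const char* /*name*/, size_t deflt,
                      const char* /*help*/) const {
        field = deflt;
      }
    };

    int main() {
      Args args;
      SetDefaults v;
      args.ForEach(v);
      return static_cast<int>(args.prefill_tbatch_size);  // 64
    }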
@@ -281,6 +289,16 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
             " Default : 0 (conversation "
             "resets every turn)");
   }
 
+  void CopyTo(RuntimeConfig& runtime_config) const {
+    runtime_config.max_tokens = max_tokens;
+    runtime_config.max_generated_tokens = max_generated_tokens;
+
+    runtime_config.prefill_tbatch_size = prefill_tbatch_size;
+    runtime_config.decode_qbatch_size = decode_qbatch_size;
+
+    runtime_config.temperature = temperature;
+  }
 };
 
 }  // namespace gcpp
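Self-contained sketch of the CopyTo flow this commit introduces: the parsed flags own the values, and CopyTo pushes them into the per-call RuntimeConfig, overriding its defaults. The structs below are simplified stand-ins, not gemma.cpp's.

    #include <cstddef>
    #include <cstdio>

    struct RuntimeConfig {
      size_t max_tokens = 0, max_generated_tokens = 0;
      size_t prefill_tbatch_size = 32, decode_qbatch_size = 16;  // defaults
      float temperature = 0.0f;
    };

    struct InferenceArgs {  // values normally come from command-line flags
      size_t max_tokens = 3072, max_generated_tokens = 2048;
      size_t prefill_tbatch_size = 64, decode_qbatch_size = 16;
      float temperature = 1.0f;

      void CopyTo(RuntimeConfig& rc) const {
        rc.max_tokens = max_tokens;
        rc.max_generated_tokens = max_generated_tokens;
        rc.prefill_tbatch_size = prefill_tbatch_size;
        rc.decode_qbatch_size = decode_qbatch_size;
        rc.temperature = temperature;
      }
    };

    int main() {
      InferenceArgs args;
      RuntimeConfig rc;
      args.CopyTo(rc);  // flag values override the RuntimeConfig defaults
      printf("prefill_tbatch = %zu\n", rc.prefill_tbatch_size);  // 64
      return 0;
    }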