mirror of https://github.com/google/gemma.cpp.git
Add prompt batching to Gemma.cpp.
This CL adds a new function to Gemma that allows for batching of multiple prompts. The function takes a vector of prompts and returns a vector of responses. The prompts are processed in parallel, and the responses are returned in the same order as the prompts. PiperOrigin-RevId: 648367559
This commit is contained in:
parent
8ac5d66575
commit
da7507e6f0
13
BUILD.bazel
13
BUILD.bazel
|
|
@ -103,6 +103,10 @@ cc_library(
|
|||
"gemma/activations.h",
|
||||
"gemma/gemma.h",
|
||||
],
|
||||
exec_properties = {
|
||||
# Avoid linker OOMs when building with sanitizer instrumentation.
|
||||
"mem": "28g",
|
||||
},
|
||||
textual_hdrs = [
|
||||
# Placeholder for internal file1, do not remove,
|
||||
# Placeholder for internal file2, do not remove,
|
||||
|
|
@ -194,6 +198,7 @@ cc_test(
|
|||
":ops",
|
||||
"@googletest//:gtest_main",
|
||||
"//compression:io",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:hwy_test_util",
|
||||
"@hwy//:thread_pool",
|
||||
],
|
||||
|
|
@ -370,6 +375,10 @@ cc_test(
|
|||
"backprop/backward_test.cc",
|
||||
"backprop/test_util.h",
|
||||
],
|
||||
exec_properties = {
|
||||
# Avoid linker OOMs when building with sanitizer instrumentation.
|
||||
"mem": "28g",
|
||||
},
|
||||
deps = [
|
||||
":backprop",
|
||||
":backprop_scalar",
|
||||
|
|
@ -406,6 +415,10 @@ cc_test(
|
|||
srcs = [
|
||||
"backprop/optimize_test.cc",
|
||||
],
|
||||
exec_properties = {
|
||||
# Avoid linker OOMs when building with sanitizer instrumentation.
|
||||
"mem": "28g",
|
||||
},
|
||||
deps = [
|
||||
":backprop",
|
||||
":common",
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@
|
|||
#include <stdio.h>
|
||||
#include <time.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <ostream>
|
||||
|
|
@ -34,6 +36,7 @@
|
|||
#include "gemma/gemma.h"
|
||||
#include "util/app.h"
|
||||
#include "util/args.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/highway.h"
|
||||
|
|
@ -72,7 +75,11 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
|
|||
} else {
|
||||
fprintf(stderr, "Loading model...\n");
|
||||
model_ = AllocateGemma(loader_, pool_);
|
||||
kv_cache_ = KVCache::Create(loader_.ModelType());
|
||||
|
||||
kv_caches_.reserve(16);
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
kv_caches_.push_back(new KVCache(KVCache::Create(loader_.ModelType())));
|
||||
}
|
||||
}
|
||||
|
||||
InitGenerator(inference_args_, gen_);
|
||||
|
|
@ -107,8 +114,9 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
|
|||
size_t total_tokens = 0;
|
||||
|
||||
const double time_start = hwy::platform::Now();
|
||||
const StreamFunc stream_token = [&res, &total_tokens, &time_start, this](
|
||||
int token, float) {
|
||||
const BatchStreamFunc batch_stream_token =
|
||||
[&res, &total_tokens, &time_start, this](
|
||||
size_t query_index, size_t pos, int token, float) {
|
||||
++total_tokens;
|
||||
res += StringFromTokens(std::vector<int>{token});
|
||||
if (app_.verbosity >= 1 && total_tokens % 128 == 0) {
|
||||
|
|
@ -123,8 +131,8 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
|
|||
<< "\ttemperature: " << inference_args_.temperature << "\n";
|
||||
}
|
||||
gcpp::TimingInfo timing_info;
|
||||
runtime_config_.stream_token = stream_token;
|
||||
model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_cache_,
|
||||
runtime_config_.batch_stream_token = batch_stream_token;
|
||||
model_->Generate(runtime_config_, tokens, /*start_pos=*/0, *kv_caches_[0],
|
||||
timing_info);
|
||||
if (app_.verbosity >= 1) {
|
||||
LogSpeedStats(time_start, total_tokens);
|
||||
|
|
@ -132,12 +140,73 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
|
|||
return {res, total_tokens};
|
||||
}
|
||||
|
||||
std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
|
||||
const hwy::Span<const hwy::Span<int>>& prompts) {
|
||||
std::vector<std::pair<std::string, size_t>> res(prompts.size());
|
||||
std::fill(res.begin(), res.end(), std::make_pair("", 0));
|
||||
size_t total_tokens = 0;
|
||||
|
||||
const double time_start = hwy::platform::Now();
|
||||
const BatchStreamFunc batch_stream_token =
|
||||
[&res, &total_tokens, &time_start, this](
|
||||
size_t query_index, size_t pos, int token, float) {
|
||||
std::string token_text;
|
||||
HWY_ASSERT(
|
||||
model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
// fprintf(stderr, "Query %zu returned token \"%s\"\n\n", query_index,
|
||||
// token_text.c_str());
|
||||
std::string single_res = res[query_index].first + token_text;
|
||||
size_t current_token_count = res[query_index].second + 1;
|
||||
res[query_index] = std::make_pair(single_res, current_token_count);
|
||||
|
||||
++total_tokens;
|
||||
if (app_.verbosity >= 1 && total_tokens % 128 == 0) {
|
||||
LogSpeedStats(time_start, total_tokens);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
if (app_.verbosity >= 2) {
|
||||
std::cout << inference_args_.max_tokens << " "
|
||||
<< inference_args_.max_generated_tokens << " "
|
||||
<< inference_args_.temperature;
|
||||
}
|
||||
gcpp::TimingInfo timing_info;
|
||||
runtime_config_.batch_stream_token = batch_stream_token;
|
||||
model_->GenerateBatch(runtime_config_, prompts, /*start_pos=*/0, kv_caches_,
|
||||
timing_info);
|
||||
if (app_.verbosity >= 1) {
|
||||
LogSpeedStats(time_start, total_tokens);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
std::pair<std::string, size_t> GemmaEnv::QueryModel(std::string& input) {
|
||||
const std::vector<int> prompt =
|
||||
WrapAndTokenize(model_->Tokenizer(), loader_.ModelTrainingType(),
|
||||
/*pos=*/0, input);
|
||||
return QueryModel(prompt);
|
||||
}
|
||||
std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
|
||||
const std::vector<std::string>& inputs) {
|
||||
std::vector<std::unique_ptr<std::vector<int>>> prompts;
|
||||
prompts.reserve(inputs.size());
|
||||
for (auto& input : inputs) {
|
||||
std::string mutable_prompt = input;
|
||||
prompts.push_back(std::make_unique<std::vector<int>>(
|
||||
WrapAndTokenize(model_->Tokenizer(),
|
||||
loader_.ModelTrainingType(),
|
||||
/*pos=*/0, mutable_prompt)));
|
||||
}
|
||||
std::vector<hwy::Span<int>> prompt_vector;
|
||||
prompt_vector.reserve(prompts.size());
|
||||
for (auto& prompt : prompts) {
|
||||
prompt_vector.push_back(hwy::Span<int>(
|
||||
prompt->data(), prompt->size()));
|
||||
}
|
||||
hwy::Span<const hwy::Span<int>> prompt_span = hwy::Span<const hwy::Span<int>>(
|
||||
prompt_vector.data(), prompt_vector.size());
|
||||
return BatchQueryModel2(prompt_span);
|
||||
}
|
||||
|
||||
float GemmaEnv::CrossEntropy(const std::string& input) {
|
||||
std::vector<int> prompt = Tokenize(input);
|
||||
|
|
|
|||
|
|
@ -69,8 +69,12 @@ class GemmaEnv {
|
|||
// Runs inference on the given input and returns the top-1 result string and
|
||||
// the number of tokens that were generated.
|
||||
std::pair<std::string, size_t> QueryModel(const std::vector<int>& tokens);
|
||||
std::vector<std::pair<std::string, size_t>> BatchQueryModel2(
|
||||
const hwy::Span<const hwy::Span<int>>& prompts);
|
||||
// Adds turn structure to input, tokenizes and calls the above overload.
|
||||
std::pair<std::string, size_t> QueryModel(std::string& input);
|
||||
std::vector<std::pair<std::string, size_t>> BatchQueryModel(
|
||||
const std::vector<std::string>& inputs);
|
||||
|
||||
// Runs inference on the given input and returns the cross entropy, a measure
|
||||
// of how well the model predicts the correct output. It is the average
|
||||
|
|
@ -87,7 +91,7 @@ class GemmaEnv {
|
|||
RuntimeConfig& MutableConfig() { return runtime_config_; }
|
||||
InferenceArgs& MutableInferenceArgs() { return inference_args_; }
|
||||
std::mt19937& MutableGen() { return gen_; }
|
||||
KVCache& MutableKVCache() { return kv_cache_; }
|
||||
KVCache& MutableKVCache() { return *kv_caches_[0]; }
|
||||
|
||||
private:
|
||||
// Arguments to the model loader: file locations, etc.
|
||||
|
|
@ -103,7 +107,7 @@ class GemmaEnv {
|
|||
// The model to run inference on.
|
||||
std::unique_ptr<Gemma> model_;
|
||||
// The KV cache to use for inference.
|
||||
KVCache kv_cache_;
|
||||
std::vector<KVCache*> kv_caches_;
|
||||
RuntimeConfig runtime_config_;
|
||||
};
|
||||
|
||||
|
|
|
|||
517
gemma/gemma.cc
517
gemma/gemma.cc
|
|
@ -75,23 +75,30 @@ struct Activations {
|
|||
static constexpr size_t kQStride = kQKVDim * (kIsMHA ? 3 : 1);
|
||||
|
||||
std::array<float, kBatchSize * kModelDim> x; // input
|
||||
|
||||
std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
|
||||
std::array<float, kBatchSize * kHeads * kQStride> q; // query vector
|
||||
std::array<float, kBatchSize * kHeads * kQStride>
|
||||
q; // query vector
|
||||
std::array<float, kBatchSize * kHeads * TConfig::kSeqLen>
|
||||
att; // attention vector
|
||||
std::array<float, kBatchSize * kHeads * kQKVDim> att_out; // attention output
|
||||
att; // attention vector
|
||||
std::array<float, kBatchSize * kHeads * kQKVDim>
|
||||
att_out; // attention output
|
||||
std::array<float, kHeads * kBatchSize * kModelDim>
|
||||
att_post1; // attention output after linear transformation, per head
|
||||
std::array<float, kBatchSize * kModelDim>
|
||||
att_post2; // accumulation of attention outputs over heads
|
||||
std::array<hwy::bfloat16_t, kBatchSize * kModelDim>
|
||||
bf_pre_ffw_rms_out;
|
||||
std::array<float, kBatchSize * TConfig::kFFHiddenDim * 2>
|
||||
ffw_hidden;
|
||||
|
||||
std::array<hwy::bfloat16_t, kBatchSize * kModelDim> bf_pre_ffw_rms_out;
|
||||
std::array<float, kBatchSize * TConfig::kFFHiddenDim * 2> ffw_hidden;
|
||||
std::array<float, kBatchSize * TConfig::kFFHiddenDim> C1; // MatMul output
|
||||
// For FFW MatMul.
|
||||
std::array<float, kBatchSize * TConfig::kFFHiddenDim> C1;
|
||||
std::array<float, kBatchSize * TConfig::kFFHiddenDim> C2;
|
||||
std::array<float, kBatchSize * kModelDim> ffw_out;
|
||||
|
||||
// bf_ version can't be used until GeluMulToBF16 issue in FFW() is resolved.
|
||||
// std::array<hwy::bfloat16_t, kBatchSize * 2 * TConfig::kFFHiddenDim>
|
||||
// bf_ffw_hidden;
|
||||
std::array<float, kBatchSize * kModelDim> ffw_out;
|
||||
std::array<float, kBatchSize * TConfig::kVocabSize> logits;
|
||||
|
||||
// For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into
|
||||
|
|
@ -104,7 +111,8 @@ struct Activations {
|
|||
std::array<float, kBatchSize * kGriffinDim> griffin_x;
|
||||
std::array<float, kBatchSize * kGriffinDim> griffin_y;
|
||||
std::array<float, kBatchSize * kGriffinDim> griffin_gate_x;
|
||||
std::array<float, kBatchSize * kGriffinDim> griffin_multiplier;
|
||||
std::array<float, kBatchSize * kGriffinDim>
|
||||
griffin_multiplier;
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
|
@ -116,10 +124,12 @@ struct CreateKVCache {
|
|||
|
||||
const size_t size_cache_pos = CachePosSize<TConfig>()();
|
||||
if (size_cache_pos != 0) {
|
||||
const size_t seq_len = TConfig::kSeqLen + kPrefillBatchSize;
|
||||
const size_t seq_len =
|
||||
(TConfig::kSeqLen + kPrefillBatchSize);
|
||||
kv_cache.kv_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
||||
}
|
||||
|
||||
// TODO(patrickms): Add query batching support for Griffin.
|
||||
if (TConfig::kGriffinLayers) {
|
||||
constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
|
||||
const size_t conv1d_cache_size =
|
||||
|
|
@ -226,19 +236,24 @@ namespace gcpp {
|
|||
namespace HWY_NAMESPACE {
|
||||
namespace {
|
||||
|
||||
template <class TConfig, size_t kBatchSize>
|
||||
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
|
||||
HWY_NOINLINE void GriffinRecurrent(
|
||||
size_t batch_start, size_t num_tokens, size_t layer,
|
||||
Activations<TConfig, kBatchSize>& activations,
|
||||
const CompressedLayer<TConfig>* layer_weights, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool) {
|
||||
size_t batch_start, size_t num_tokens, size_t num_queries, size_t layer,
|
||||
Activations<TConfig, kBatchSize * kQueryBatchSize>& activations,
|
||||
const CompressedLayer<TConfig>* layer_weights,
|
||||
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool) {
|
||||
PROFILER_ZONE("Gen.Griffin");
|
||||
static_assert(kQueryBatchSize == 1,
|
||||
"Griffin does not support batched queries.");
|
||||
HWY_DASSERT(num_queries == 1); // TODO: add batch query support for Griffin.
|
||||
KVCache& kv_cache = *kv_caches[0];
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
HWY_DASSERT(num_tokens <= kBatchSize);
|
||||
constexpr size_t kModelDim = Activations<TConfig, kBatchSize>::kModelDim;
|
||||
constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
|
||||
constexpr size_t kHeads = TConfig::kHeads;
|
||||
static constexpr size_t kModelDim =
|
||||
gcpp::Activations<TConfig, kBatchSize * kQueryBatchSize>::kModelDim;
|
||||
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
|
||||
static constexpr size_t kHeads = TConfig::kHeads;
|
||||
|
||||
// X / Y linear layers.
|
||||
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
|
||||
|
|
@ -357,14 +372,19 @@ HWY_NOINLINE void GriffinRecurrent(
|
|||
}
|
||||
}
|
||||
|
||||
template <class TConfig, size_t kBatchSize>
|
||||
HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
|
||||
Activations<TConfig, kBatchSize>& activations,
|
||||
const CompressedLayer<TConfig>* layer_weights,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool) {
|
||||
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
|
||||
HWY_NOINLINE void Attention(
|
||||
size_t batch_and_query_start, size_t num_tokens, size_t num_queries,
|
||||
size_t layer,
|
||||
Activations<TConfig, kBatchSize * kQueryBatchSize>& activations,
|
||||
const CompressedLayer<TConfig>* layer_weights,
|
||||
const std::vector<KVCache*>& kv_caches,
|
||||
hwy::ThreadPool& pool) {
|
||||
PROFILER_ZONE("Gen.Attention");
|
||||
HWY_DASSERT(num_tokens <= kBatchSize);
|
||||
using TActivations = Activations<TConfig, kBatchSize>;
|
||||
HWY_DASSERT(num_queries <= kQueryBatchSize);
|
||||
HWY_DASSERT(batch_and_query_start % num_queries == 0);
|
||||
using TActivations = Activations<TConfig, kBatchSize * kQueryBatchSize>;
|
||||
constexpr size_t kQKVDim = TActivations::kQKVDim;
|
||||
constexpr size_t kQStride = TActivations::kQStride;
|
||||
constexpr size_t kCachePosSize = CachePosSize<TConfig>()();
|
||||
|
|
@ -376,15 +396,22 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
|
|||
GEMMA_CONSTEXPR_SQRT const float kQueryScale =
|
||||
1.0f / Sqrt(static_cast<float>(kQKVDim));
|
||||
constexpr bool kIsMHA = TActivations::kIsMHA; // Multi-Head Attention
|
||||
const size_t batch_start = batch_and_query_start / num_queries;
|
||||
const size_t num_tokens_and_queries = num_tokens * num_queries;
|
||||
|
||||
// If MHA, this also computes KV, which we copy to the KV cache below.
|
||||
static_assert(!kIsMHA || TConfig::kInterleaveQKV); // MHA => interleaved
|
||||
MatMul_4x4_Batch<kModelDim, kHeads * kQStride>(
|
||||
num_tokens, activations.pre_att_rms_out.data(),
|
||||
num_tokens_and_queries, activations.pre_att_rms_out.data(),
|
||||
layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);
|
||||
|
||||
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
|
||||
const float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
|
||||
for (size_t batch_and_query_idx = 0;
|
||||
batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
|
||||
const float* x = activations.pre_att_rms_out.data() + batch_and_query_idx
|
||||
* kModelDim;
|
||||
const size_t query_idx = batch_and_query_idx % num_queries;
|
||||
const size_t batch_idx = batch_and_query_idx / num_queries;
|
||||
KVCache& kv_cache = *kv_caches[query_idx];
|
||||
// QKV projections:
|
||||
if constexpr (!kIsMHA) {
|
||||
const size_t pos = batch_start + batch_idx;
|
||||
|
|
@ -401,18 +428,23 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
|
|||
|
||||
// Positional encodings for kv:
|
||||
pool.Run(
|
||||
0, kKVHeads * num_tokens, [&](uint64_t task, size_t thread) HWY_ATTR {
|
||||
0, kKVHeads * num_tokens_and_queries,
|
||||
[&](uint64_t task, size_t thread) HWY_ATTR {
|
||||
const size_t head = task % kKVHeads;
|
||||
const size_t batch_idx = task / kKVHeads;
|
||||
const size_t batch_and_query_idx = task / kKVHeads;
|
||||
const size_t query_idx = batch_and_query_idx % num_queries;
|
||||
const size_t batch_idx = batch_and_query_idx / num_queries;
|
||||
const size_t pos = batch_start + batch_idx;
|
||||
const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
|
||||
const size_t kv_offset = cache_pos * kCachePosSize +
|
||||
layer * kCacheLayerSize + head * kQKVDim * 2;
|
||||
KVCache& kv_cache = *kv_caches[query_idx];
|
||||
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
||||
if constexpr (kIsMHA) {
|
||||
// For MHA, copy kv into the KV cache from scratch space (see above).
|
||||
const float* HWY_RESTRICT q =
|
||||
activations.q.data() + (batch_idx * kHeads + head) * kQStride;
|
||||
activations.q.data() + (batch_and_query_idx * kHeads
|
||||
+ head) * kQStride;
|
||||
// Skip past the Q part of `q`, and copy KV to `kv`.
|
||||
memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
|
||||
}
|
||||
|
|
@ -422,17 +454,22 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
|
|||
static_assert((kHeads % kKVHeads) == 0,
|
||||
"query heads must be a multiple of key-value heads");
|
||||
constexpr size_t kGroupHeads = kHeads / kKVHeads;
|
||||
pool.Run(0, kHeads * num_tokens, [&](uint64_t task, size_t thread) HWY_ATTR {
|
||||
pool.Run(0, kHeads * num_tokens_and_queries,
|
||||
[&](uint64_t task, size_t thread) HWY_ATTR {
|
||||
const size_t head = task % kHeads;
|
||||
const size_t batch_idx = task / kHeads;
|
||||
const size_t batch_and_query_idx = task / kHeads;
|
||||
const size_t query_idx = batch_and_query_idx % num_queries;
|
||||
const size_t batch_idx = batch_and_query_idx / num_queries;
|
||||
const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2;
|
||||
KVCache& kv_cache = *kv_caches[query_idx];
|
||||
float* HWY_RESTRICT q =
|
||||
activations.q.data() + (batch_idx * kHeads + head) * kQStride;
|
||||
activations.q.data() + (batch_and_query_idx * kHeads + head) * kQStride;
|
||||
|
||||
const size_t pos = batch_start + batch_idx;
|
||||
// Calculate scores
|
||||
float* HWY_RESTRICT head_att =
|
||||
activations.att.data() + head * kSeqLen + batch_idx * kHeads * kSeqLen;
|
||||
activations.att.data() + head * kSeqLen
|
||||
+ batch_and_query_idx * kHeads * kSeqLen;
|
||||
|
||||
Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
|
||||
MulByConst(kQueryScale, q, kQKVDim);
|
||||
|
|
@ -451,7 +488,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
|
|||
|
||||
// Weighted summation
|
||||
float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
|
||||
batch_idx * kHeads * kQKVDim;
|
||||
batch_and_query_idx * kHeads * kQKVDim;
|
||||
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
|
||||
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
|
||||
const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
|
||||
|
|
@ -462,20 +499,23 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
|
|||
}
|
||||
});
|
||||
|
||||
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
|
||||
for (size_t batch_and_query_idx = 0;
|
||||
batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
|
||||
// TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
|
||||
// rearranging the weights.
|
||||
float* HWY_RESTRICT att_out =
|
||||
activations.att_out.data() + batch_idx * kHeads * kQKVDim;
|
||||
activations.att_out.data() + batch_and_query_idx * kHeads * kQKVDim;
|
||||
float* HWY_RESTRICT layer_out =
|
||||
activations.att_post2.data() + batch_idx * kModelDim;
|
||||
activations.att_post2.data() + batch_and_query_idx * kModelDim;
|
||||
MatVecT</*kAdd=*/TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
|
||||
layer_weights->attn_vec_einsum_w, 0, att_out,
|
||||
layer_weights->attention_output_biases.data(),
|
||||
activations.even_odd.data(), layer_out, pool);
|
||||
for (size_t head = 1; head < kHeads; ++head) {
|
||||
// TODO(patrickms): Check this calculation
|
||||
float* HWY_RESTRICT head_out =
|
||||
activations.att_post1.data() + head * kBatchSize * kModelDim;
|
||||
activations.att_post1.data() +
|
||||
head * kBatchSize * kQueryBatchSize * kModelDim;
|
||||
// TODO: requires MatMul support for offsets.
|
||||
MatVec<kModelDim, kQKVDim>(
|
||||
layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim,
|
||||
|
|
@ -583,103 +623,138 @@ HWY_NOINLINE void EmbedToken(int token, size_t token_idx, size_t pos,
|
|||
};
|
||||
}
|
||||
|
||||
template <class TConfig, size_t kBatchSize>
|
||||
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
|
||||
HWY_NOINLINE void TransformerLayer(
|
||||
size_t num_tokens, size_t pos, size_t layer,
|
||||
size_t num_tokens, size_t num_queries, size_t pos, size_t layer,
|
||||
const CompressedLayer<TConfig>* layer_weights,
|
||||
Activations<TConfig, kBatchSize>& activations, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool) {
|
||||
Activations<TConfig, kBatchSize * kQueryBatchSize>& activations,
|
||||
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool) {
|
||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
const size_t num_tokens_and_queries = num_tokens * num_queries;
|
||||
auto type = TConfig::kLayerConfig[layer];
|
||||
size_t layer_of_type =
|
||||
NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
|
||||
RMSNormBatched<kBatchSize>(num_tokens, activations.x.data(),
|
||||
layer_weights->pre_attention_norm_scale.data(),
|
||||
activations.pre_att_rms_out.data(), kModelDim);
|
||||
RMSNormBatched<kBatchSize * kQueryBatchSize>(
|
||||
num_tokens_and_queries, activations.x.data(),
|
||||
layer_weights->pre_attention_norm_scale.data(),
|
||||
activations.pre_att_rms_out.data(), kModelDim);
|
||||
if (type == LayerAttentionType::kGemma) {
|
||||
Attention(pos, num_tokens, layer_of_type, activations, layer_weights,
|
||||
kv_cache, pool);
|
||||
Attention<TConfig, kBatchSize, kQueryBatchSize>(
|
||||
pos, num_tokens, num_queries, layer_of_type, activations,
|
||||
layer_weights, kv_caches, pool);
|
||||
} else {
|
||||
GriffinRecurrent(pos, num_tokens, layer_of_type, activations, layer_weights,
|
||||
kv_cache, pool);
|
||||
// This Griffin layers should never exist unless the model is a Griffin
|
||||
// model. This conditional prevents the compiler from generating code for
|
||||
// this branch when building a non-Griffin model, since we have static
|
||||
// asserts about the query batch size for Griffin layers.
|
||||
if constexpr (TConfig::kGriffinLayers > 0) {
|
||||
GriffinRecurrent<TConfig, kBatchSize, kQueryBatchSize>(
|
||||
pos, num_tokens, num_queries, layer_of_type, activations,
|
||||
layer_weights, kv_caches, pool);
|
||||
}
|
||||
}
|
||||
if (TConfig::kPostNormScale) {
|
||||
RMSNormInplaceBatched<kBatchSize>(
|
||||
num_tokens, layer_weights->post_attention_norm_scale.data(),
|
||||
RMSNormInplaceBatched<kBatchSize * kQueryBatchSize>(
|
||||
num_tokens_and_queries,
|
||||
layer_weights->post_attention_norm_scale.data(),
|
||||
activations.att_post2.data(), kModelDim);
|
||||
}
|
||||
AddFromBatched<kBatchSize>(num_tokens, activations.att_post2.data(),
|
||||
activations.x.data(), kModelDim);
|
||||
RMSNormBatched<kBatchSize>(num_tokens, activations.x.data(),
|
||||
layer_weights->pre_ffw_norm_scale.data(),
|
||||
activations.bf_pre_ffw_rms_out.data(), kModelDim);
|
||||
FFW(activations, num_tokens, layer_weights, pool);
|
||||
AddFromBatched<kBatchSize * kQueryBatchSize>(num_tokens_and_queries,
|
||||
activations.att_post2.data(),
|
||||
activations.x.data(), kModelDim);
|
||||
RMSNormBatched<kBatchSize * kQueryBatchSize>(
|
||||
num_tokens_and_queries, activations.x.data(),
|
||||
layer_weights->pre_ffw_norm_scale.data(),
|
||||
activations.bf_pre_ffw_rms_out.data(), kModelDim);
|
||||
FFW<TConfig, kBatchSize * kQueryBatchSize>(
|
||||
activations, num_tokens_and_queries, layer_weights, pool);
|
||||
if (TConfig::kPostNormScale) {
|
||||
RMSNormInplaceBatched<kBatchSize>(num_tokens,
|
||||
layer_weights->post_ffw_norm_scale.data(),
|
||||
activations.ffw_out.data(), kModelDim);
|
||||
RMSNormInplaceBatched<kBatchSize * kQueryBatchSize>(
|
||||
num_tokens_and_queries, layer_weights->post_ffw_norm_scale.data(),
|
||||
activations.ffw_out.data(), kModelDim);
|
||||
}
|
||||
AddFromBatched<kBatchSize>(num_tokens, activations.ffw_out.data(),
|
||||
activations.x.data(), kModelDim);
|
||||
AddFromBatched<kBatchSize * kQueryBatchSize>(
|
||||
num_tokens_and_queries, activations.ffw_out.data(),
|
||||
activations.x.data(), kModelDim);
|
||||
}
|
||||
|
||||
template <class TConfig, size_t kBatchSize>
|
||||
HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
||||
const CompressedWeights<TConfig>& weights,
|
||||
Activations<TConfig, kBatchSize>& activations,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool) {
|
||||
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
|
||||
HWY_NOINLINE void Prefill(
|
||||
const int* tokens, size_t num_tokens, size_t num_queries, size_t pos,
|
||||
const CompressedWeights<TConfig>& weights,
|
||||
Activations<TConfig, kBatchSize * kQueryBatchSize>& activations,
|
||||
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool) {
|
||||
HWY_DASSERT(num_queries <= kQueryBatchSize);
|
||||
const size_t minibatch_size = std::min(num_tokens, kBatchSize);
|
||||
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
|
||||
// TODO(patrickms): Try to hoist pool.Run out of the loop.
|
||||
for (size_t i = 0; i < num_tokens; i += minibatch_size) {
|
||||
const size_t offset = i * num_queries;
|
||||
const size_t current_token_count = std::min(
|
||||
minibatch_size, num_tokens - i);
|
||||
pool.Run(0, current_token_count * num_queries,
|
||||
[&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
||||
EmbedToken<TConfig, kBatchSize * kQueryBatchSize>(
|
||||
tokens[token_idx + offset], token_idx, pos + offset,
|
||||
weights, activations);
|
||||
});
|
||||
|
||||
pool.Run(
|
||||
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
||||
EmbedToken(tokens[token_idx], token_idx, pos, weights, activations);
|
||||
});
|
||||
|
||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||
const auto* layer_weights = weights.GetLayer(layer);
|
||||
TransformerLayer(num_tokens, pos, layer, layer_weights, activations,
|
||||
kv_cache, pool);
|
||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||
const auto* layer_weights = weights.GetLayer(layer);
|
||||
TransformerLayer<TConfig, kBatchSize, kQueryBatchSize>(
|
||||
current_token_count, num_queries, pos + offset , layer, layer_weights,
|
||||
activations, kv_caches, pool);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Compute the transformer for a batch of input tokens. During generation,
|
||||
// we usually have num_tokens == 1 (and also kBatchSize == 1).
|
||||
template <class TConfig, size_t kBatchSize>
|
||||
HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, size_t pos,
|
||||
const CompressedWeights<TConfig>& weights,
|
||||
Activations<TConfig, kBatchSize>& activations,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const LayersOutputFunc& layers_output) {
|
||||
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
|
||||
HWY_NOINLINE void Transformer(
|
||||
const int* tokens, size_t num_tokens, size_t num_queries, size_t pos,
|
||||
const CompressedWeights<TConfig>& weights,
|
||||
Activations<TConfig, kBatchSize * kQueryBatchSize>& activations,
|
||||
const std::vector<KVCache*>& kv_caches,
|
||||
hwy::ThreadPool& pool,
|
||||
const LayersOutputFunc& layers_output) {
|
||||
HWY_ASSERT(num_tokens <= kBatchSize);
|
||||
const size_t num_tokens_and_queries = num_tokens * num_queries;
|
||||
if (layers_output) {
|
||||
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
|
||||
++token_idx) {
|
||||
float token_f = tokens[token_idx];
|
||||
layers_output(pos + token_idx, "Tokens", &token_f, 1);
|
||||
}
|
||||
}
|
||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
EmbedToken(tokens[token_idx], token_idx, pos, weights, activations);
|
||||
for (size_t token_idx = 0; token_idx < num_tokens_and_queries; ++token_idx) {
|
||||
EmbedToken<TConfig, kBatchSize * kQueryBatchSize>(
|
||||
tokens[token_idx], token_idx, pos, weights, activations);
|
||||
}
|
||||
|
||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||
const CompressedLayer<TConfig>* layer_weights = weights.GetLayer(layer);
|
||||
TransformerLayer(num_tokens, pos, layer, layer_weights, activations,
|
||||
kv_cache, pool);
|
||||
TransformerLayer<TConfig, kBatchSize, kQueryBatchSize>(
|
||||
num_tokens, num_queries, pos, layer, layer_weights,
|
||||
activations, kv_caches, pool);
|
||||
|
||||
if (layers_output) {
|
||||
const std::string block_name = "blocks." + std::to_string(layer);
|
||||
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
|
||||
++token_idx) {
|
||||
layers_output(pos + token_idx, block_name,
|
||||
activations.x.data() + token_idx * kModelDim, kModelDim);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
RMSNormInplaceBatched<kBatchSize>(num_tokens, weights.final_norm_scale.data(),
|
||||
activations.x.data(), kModelDim);
|
||||
RMSNormInplaceBatched<kBatchSize * kQueryBatchSize>(
|
||||
num_tokens * num_queries, weights.final_norm_scale.data(),
|
||||
activations.x.data(), kModelDim);
|
||||
if (layers_output) {
|
||||
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
|
||||
++token_idx) {
|
||||
layers_output(pos + token_idx, "final_norm",
|
||||
activations.x.data() + token_idx * kModelDim, kModelDim);
|
||||
}
|
||||
|
|
@ -719,31 +794,69 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
|
|||
}
|
||||
|
||||
template <class TConfig, size_t kBatchSize>
|
||||
Activations<TConfig, kBatchSize>& GetActivations(const ByteStorageT& state_u8) {
|
||||
return *reinterpret_cast<Activations<TConfig, kBatchSize>*>(state_u8.get());
|
||||
Activations<TConfig, kBatchSize>& GetActivations(
|
||||
const ByteStorageT& state_u8) {
|
||||
return *reinterpret_cast<Activations<TConfig, kBatchSize>*>(
|
||||
state_u8.get());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Placeholder for internal test3, do not remove
|
||||
|
||||
template <class TConfig>
|
||||
bool StreamToken(size_t query_idx, size_t pos, int token, float weight,
|
||||
const RuntimeConfig& runtime_config) {
|
||||
if (runtime_config.batch_stream_token) {
|
||||
return runtime_config.batch_stream_token(query_idx, pos, token, weight);
|
||||
}
|
||||
return runtime_config.stream_token(token, weight);
|
||||
}
|
||||
|
||||
template <class TConfig, size_t kQueryBatchSize>
|
||||
void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
|
||||
const ByteStorageT& decode_u8,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, TimingInfo& timing_info) {
|
||||
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
|
||||
const size_t query_index_offset,
|
||||
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
|
||||
TimingInfo& timing_info) {
|
||||
constexpr size_t kAdjustedPrefillBatchSize =
|
||||
std::max((size_t)1, kPrefillBatchSize / kQueryBatchSize);
|
||||
static_assert(kAdjustedPrefillBatchSize >= kMinAdjustedPrefillBatchSize);
|
||||
const size_t num_queries = prompts.size();
|
||||
HWY_DASSERT(num_queries <= kQueryBatchSize);
|
||||
pos *= num_queries; // position in (num_queries) interleaved token sequence.
|
||||
const CompressedWeights<TConfig>& weights =
|
||||
*reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
|
||||
auto& prefill_activations =
|
||||
GetActivations<TConfig, kPrefillBatchSize>(prefill_u8);
|
||||
auto& activations = GetActivations<TConfig, 1>(decode_u8);
|
||||
GetActivations<TConfig,
|
||||
kAdjustedPrefillBatchSize * kQueryBatchSize>(prefill_u8);
|
||||
auto& activations = GetActivations<TConfig, kQueryBatchSize>(decode_u8);
|
||||
|
||||
size_t min_prompt_size = (size_t)-1;
|
||||
size_t max_prompt_size = 0;
|
||||
for (int i=0; i < prompts.size(); ++i) {
|
||||
min_prompt_size = std::min(min_prompt_size, prompts[i].size());
|
||||
max_prompt_size = std::max(max_prompt_size, prompts[i].size());
|
||||
}
|
||||
|
||||
std::vector<int> prompt;
|
||||
prompt.reserve(max_prompt_size * prompts.size());
|
||||
for (int i = 0; i < max_prompt_size; ++i) {
|
||||
for (int j=0; j < prompts.size(); ++j) {
|
||||
if (i < prompts[j].size()) {
|
||||
prompt.push_back(prompts[j][i]);
|
||||
} else {
|
||||
prompt.push_back(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||
size_t prompt_size = prompt.size();
|
||||
|
||||
size_t max_tokens = runtime_config.max_tokens;
|
||||
size_t max_generated_tokens = runtime_config.max_generated_tokens;
|
||||
RangeChecks<TConfig>(max_tokens, max_generated_tokens, prompt_size);
|
||||
RangeChecks<TConfig>(max_tokens, max_generated_tokens, max_prompt_size);
|
||||
if (pos >= max_tokens) {
|
||||
fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", pos,
|
||||
max_tokens);
|
||||
|
|
@ -760,6 +873,9 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
|
|||
runtime_config.accept_token);
|
||||
};
|
||||
|
||||
std::vector<bool> reached_eos(num_queries);
|
||||
std::fill(reached_eos.begin(), reached_eos.end(), false);
|
||||
|
||||
// pos indexes the KV cache. In the first turn of a chat, pos = 0.
|
||||
//
|
||||
// After the first turn, pos gets passed in with > 0 corresponding to the
|
||||
|
|
@ -772,23 +888,44 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
|
|||
// In single-turn (non-chat) usage, pos and pos_offset start at 0 and are
|
||||
// always equal.
|
||||
size_t pos_offset = 0; // offset relative to pos
|
||||
// Used to keep track of how many tokens are processed per prompt,
|
||||
// so that we know when to start generating tokens.
|
||||
size_t single_prompt_pos_offset = 0;
|
||||
const double prefill_start = hwy::platform::Now();
|
||||
|
||||
// Prefill stops before prompt_size - 1 since the last prompt token is the
|
||||
// first input token for generation.
|
||||
while (pos_offset < prompt_size - 1) {
|
||||
const size_t batch_size =
|
||||
std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset);
|
||||
while (single_prompt_pos_offset < min_prompt_size - 1) {
|
||||
const size_t batch_size = std::min(
|
||||
kPrefillBatchSize, min_prompt_size - 1 - single_prompt_pos_offset);
|
||||
const size_t batch_and_query_size = batch_size * num_queries;
|
||||
HWY_DASSERT(batch_size <= kPrefillBatchSize);
|
||||
HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
|
||||
HWY_DASSERT(single_prompt_pos_offset + batch_size <= min_prompt_size - 1);
|
||||
HWY_DASSERT(pos_offset + batch_size <= (min_prompt_size - 1) * num_queries);
|
||||
const int* batch_tokens = prompt.data() + pos_offset;
|
||||
Prefill(batch_tokens, batch_size, pos, weights, prefill_activations,
|
||||
kv_cache, pool);
|
||||
Prefill<TConfig, kAdjustedPrefillBatchSize, kQueryBatchSize>(
|
||||
batch_tokens, batch_size, num_queries, pos, weights,
|
||||
prefill_activations, kv_caches, pool);
|
||||
for (size_t idx = 0; idx < batch_size; ++idx) {
|
||||
if (!runtime_config.stream_token(batch_tokens[idx], 0.0f)) return;
|
||||
bool all_tokens_eos = true;
|
||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||
if (reached_eos[query_idx]) continue;
|
||||
if (StreamToken(
|
||||
query_idx + query_index_offset, single_prompt_pos_offset,
|
||||
batch_tokens[idx * num_queries + query_idx], 0.0f,
|
||||
runtime_config)) {
|
||||
all_tokens_eos = false;
|
||||
} else {
|
||||
reached_eos[query_idx] = true;
|
||||
}
|
||||
}
|
||||
if (all_tokens_eos) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
pos += batch_size;
|
||||
pos_offset += batch_size;
|
||||
pos += batch_and_query_size;
|
||||
pos_offset += batch_and_query_size;
|
||||
single_prompt_pos_offset += batch_size;
|
||||
}
|
||||
|
||||
timing_info.prefill_tok_sec =
|
||||
|
|
@ -796,45 +933,81 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
|
|||
|
||||
// Start generation.
|
||||
const double gen_start = hwy::platform::Now();
|
||||
HWY_DASSERT(pos_offset == prompt_size - 1);
|
||||
HWY_DASSERT(single_prompt_pos_offset == min_prompt_size - 1);
|
||||
size_t pos_gen_start = pos_offset;
|
||||
int token = prompt.at(pos_offset);
|
||||
// The loop below is not yet prepared for batch size > 1.
|
||||
std::vector<int>::const_iterator first = prompt.begin() + pos_offset;
|
||||
std::vector<int>::const_iterator last = first + num_queries;
|
||||
std::vector<int> gen_tokens(first, last);
|
||||
// The loop below is not yet prepared for decode batch size > 1.
|
||||
HWY_ASSERT(kDecodeBatchSize == 1);
|
||||
if (!runtime_config.stream_token(token, 0.0f)) return;
|
||||
bool all_tokens_eos = true;
|
||||
for (size_t i=0; i < num_queries; ++i) {
|
||||
if (reached_eos[i]) continue;
|
||||
if (StreamToken(i + query_index_offset,
|
||||
single_prompt_pos_offset, gen_tokens[i], 0.0f,
|
||||
runtime_config)) {
|
||||
all_tokens_eos = false;
|
||||
} else {
|
||||
reached_eos[i] = true;
|
||||
}
|
||||
}
|
||||
if (all_tokens_eos) {
|
||||
return;
|
||||
}
|
||||
for (size_t generate_pos = 0;
|
||||
pos < max_tokens && generate_pos < max_generated_tokens;
|
||||
++pos, ++pos_offset, ++generate_pos) {
|
||||
Transformer(&token, kDecodeBatchSize, pos, weights, activations, kv_cache,
|
||||
pool, runtime_config.layers_output);
|
||||
generate_pos < max_tokens && generate_pos < max_generated_tokens;
|
||||
++single_prompt_pos_offset, ++generate_pos) {
|
||||
Transformer<TConfig, kDecodeBatchSize, kQueryBatchSize>(
|
||||
gen_tokens.data(), kDecodeBatchSize, num_queries, pos, weights,
|
||||
activations, kv_caches, pool, runtime_config.layers_output);
|
||||
float token_logit = 0.0f;
|
||||
// The condition below is always true if we are doing Prefill above.
|
||||
// We keep it here for clarity so that the code is correct even if Prefill
|
||||
// is disabled.
|
||||
const bool is_generating_phase = pos_offset >= prompt_size - 1;
|
||||
if (is_generating_phase) {
|
||||
PROFILER_ZONE("Gen.Embedding");
|
||||
// Compute logits from last layer activations.
|
||||
MatVec<kVocabSize, TConfig::kModelDim>(
|
||||
weights.embedder_input_embedding, 0, activations.x.data(),
|
||||
activations.even_odd.data(), activations.logits.data(), pool);
|
||||
LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
|
||||
// Barrier: must have all logits so we can subtract max.
|
||||
Softmax(activations.logits.data(), kVocabSize);
|
||||
token = sample_token(activations.logits.data(), kVocabSize);
|
||||
token_logit = activations.logits[token];
|
||||
if (generate_pos == 0) {
|
||||
timing_info.time_to_first_token = hwy::platform::Now() - gen_start;
|
||||
bool all_tokens_eos = true;
|
||||
float* x = activations.x.data();
|
||||
float* logits = activations.logits.data();
|
||||
for (size_t i = 0; i < num_queries; ++i, ++pos, ++pos_offset,
|
||||
x += TConfig::kModelDim, logits += kVocabSize) {
|
||||
const size_t prompt_size = prompts[i].size();
|
||||
const bool is_generating_phase =
|
||||
(single_prompt_pos_offset >= prompt_size - 1);
|
||||
if (is_generating_phase) {
|
||||
PROFILER_ZONE("Gen.Embedding");
|
||||
// Compute logits from last layer activations.
|
||||
MatVec<kVocabSize, TConfig::kModelDim>(
|
||||
weights.embedder_input_embedding, 0, x, activations.even_odd.data(),
|
||||
logits, pool);
|
||||
LogitsSoftCap(30.0f, logits, kVocabSize);
|
||||
// Barrier: must have all logits so we can subtract max.
|
||||
Softmax(logits, kVocabSize);
|
||||
token = sample_token(logits, kVocabSize);
|
||||
token_logit = logits[token];
|
||||
if (generate_pos == 0) {
|
||||
timing_info.time_to_first_token = hwy::platform::Now() - gen_start;
|
||||
}
|
||||
} else {
|
||||
// We would take this branch if we were not doing Prefill but would
|
||||
// process the tokens of the prompt one at a time.
|
||||
token = prompt.at(pos_offset);
|
||||
token_logit = 0.0f;
|
||||
}
|
||||
} else {
|
||||
// We would take this branch if we were not doing Prefill but would
|
||||
// process the tokens of the prompt one at a time.
|
||||
token = prompt.at(pos_offset + 1);
|
||||
|
||||
if (!reached_eos[i]) {
|
||||
if (!StreamToken(i + query_index_offset, single_prompt_pos_offset+1,
|
||||
token, token_logit, runtime_config)) {
|
||||
token = runtime_config.eos_id;
|
||||
}
|
||||
if (token != runtime_config.eos_id) {
|
||||
all_tokens_eos = false;
|
||||
} else {
|
||||
reached_eos[i] = true;
|
||||
}
|
||||
}
|
||||
gen_tokens[i] = token;
|
||||
}
|
||||
if (!runtime_config.stream_token(token, token_logit)) {
|
||||
token = runtime_config.eos_id;
|
||||
}
|
||||
if (token == runtime_config.eos_id) {
|
||||
if (all_tokens_eos) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
@ -842,6 +1015,46 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
|
|||
(hwy::platform::Now() - gen_start);
|
||||
}
|
||||
|
||||
template <class TConfig>
|
||||
void GenerateOneQueryT(const ByteStorageT& weights_u8,
|
||||
const ByteStorageT& prefill_u8,
|
||||
const ByteStorageT& decode_u8,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt, size_t pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
TimingInfo& timing_info) {
|
||||
std::vector<hwy::Span<int>> prompt_vector = {
|
||||
hwy::Span<int>(const_cast<int*>(prompt.data()), prompt.size())};
|
||||
const hwy::Span<const hwy::Span<int>> prompts(
|
||||
prompt_vector.data(), prompt_vector.size());
|
||||
std::vector<KVCache*> kv_caches = {&kv_cache};
|
||||
GenerateT<TConfig, 1>(weights_u8, prefill_u8, decode_u8,
|
||||
runtime_config, prompts, pos, 0,
|
||||
kv_caches, pool, timing_info);
|
||||
}
|
||||
|
||||
template <class TConfig>
|
||||
void GenerateBatchT(const ByteStorageT& weights_u8,
|
||||
const ByteStorageT& prefill_u8,
|
||||
const ByteStorageT& decode_u8,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const hwy::Span<const hwy::Span<int>>& prompts,
|
||||
size_t pos, const std::vector<KVCache*>& kv_caches,
|
||||
hwy::ThreadPool& pool,
|
||||
TimingInfo& timing_info) {
|
||||
// Disable query batching for Griffin models.
|
||||
constexpr size_t kQueryBatchSize =
|
||||
(TConfig::kGriffinLayers > 0) ? 1 : kBatchedQueryBatchSize;
|
||||
for (size_t i = 0; i < prompts.size(); i += kQueryBatchSize) {
|
||||
const size_t num_queries = std::min(prompts.size() - i, kQueryBatchSize);
|
||||
const hwy::Span<const hwy::Span<int>> current_prompts(
|
||||
prompts.data() + i, num_queries);
|
||||
GenerateT<TConfig, kQueryBatchSize>(weights_u8, prefill_u8, decode_u8,
|
||||
runtime_config, current_prompts,
|
||||
pos, i, kv_caches, pool, timing_info);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
|
@ -853,8 +1066,13 @@ namespace {
|
|||
template <typename TConfig>
|
||||
struct AllocateState {
|
||||
void operator()(ByteStorageT& prefill, ByteStorageT& decode) const {
|
||||
prefill = AllocateSizeof<Activations<TConfig, kPrefillBatchSize>>();
|
||||
decode = AllocateSizeof<Activations<TConfig, kDecodeBatchSize>>();
|
||||
// When batching queries, the prefill batch size is reduced by a factor
|
||||
// of kBatchedQueryBatchSize
|
||||
prefill = AllocateSizeof<
|
||||
Activations<TConfig,
|
||||
kMinAdjustedPrefillBatchSize * kBatchedQueryBatchSize>>();
|
||||
decode = AllocateSizeof<
|
||||
Activations<TConfig, kDecodeBatchSize * kBatchedQueryBatchSize>>();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -895,13 +1113,28 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
|
|||
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||
|
||||
GEMMA_EXPORT_AND_DISPATCH(
|
||||
model_type_, weight_type_, GenerateT,
|
||||
model_type_, weight_type_, GenerateOneQueryT,
|
||||
(weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompt, start_pos,
|
||||
kv_cache, pool_, timing_info));
|
||||
|
||||
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||
}
|
||||
|
||||
void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
||||
const hwy::Span<const hwy::Span<int>>& prompts,
|
||||
size_t start_pos,
|
||||
const std::vector<KVCache*>& kv_caches,
|
||||
TimingInfo& timing_info) {
|
||||
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||
|
||||
GEMMA_EXPORT_AND_DISPATCH(
|
||||
model_type_, weight_type_, GenerateBatchT,
|
||||
(weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompts, start_pos,
|
||||
kv_caches, pool_, timing_info));
|
||||
|
||||
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||
}
|
||||
|
||||
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
|
||||
const ModelTraining training, size_t pos,
|
||||
std::string& prompt) {
|
||||
|
|
|
|||
|
|
@ -32,6 +32,9 @@ namespace gcpp {
|
|||
|
||||
constexpr size_t kPrefillBatchSize = 16;
|
||||
constexpr size_t kDecodeBatchSize = 1;
|
||||
constexpr size_t kBatchedQueryBatchSize = 16;
|
||||
constexpr size_t kMinAdjustedPrefillBatchSize =
|
||||
HWY_MAX((size_t)1, kPrefillBatchSize / kBatchedQueryBatchSize);
|
||||
constexpr bool kSystemPrompt = false;
|
||||
|
||||
struct KVCache {
|
||||
|
|
@ -72,6 +75,11 @@ class GemmaTokenizer {
|
|||
// probability is 0.0f. StreamFunc should return false to stop generation and
|
||||
// true to continue generation.
|
||||
using StreamFunc = std::function<bool(int, float)>;
|
||||
// BatchStreamFunc is called with (query_idx, pos, token, probability).
|
||||
// For prompt tokens,
|
||||
// probability is 0.0f. StreamFunc should return false to stop generation and
|
||||
// true to continue generation.
|
||||
using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
|
||||
// If not empty, AcceptFunc is called with token. It should return false for
|
||||
// tokens you don't want to generate and true for tokens you want to generate.
|
||||
using AcceptFunc = std::function<bool(int, float)>;
|
||||
|
|
@ -93,6 +101,7 @@ struct RuntimeConfig {
|
|||
int verbosity;
|
||||
std::mt19937* gen;
|
||||
StreamFunc stream_token;
|
||||
BatchStreamFunc batch_stream_token;
|
||||
AcceptFunc accept_token; // if empty, accepts all tokens.
|
||||
SampleFunc sample_func; // if empty, uses SampleTopK.
|
||||
LayersOutputFunc layers_output; // if not empty, called after each layer.
|
||||
|
|
@ -125,6 +134,11 @@ class Gemma {
|
|||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, TimingInfo& timing_info);
|
||||
|
||||
void GenerateBatch(const RuntimeConfig& runtime_config,
|
||||
const hwy::Span<const hwy::Span<int>>& prompts,
|
||||
size_t start_pos, const std::vector<KVCache*>& kv_caches,
|
||||
TimingInfo& timing_info);
|
||||
|
||||
private:
|
||||
hwy::ThreadPool& pool_;
|
||||
|
||||
|
|
|
|||
|
|
@ -17,9 +17,11 @@
|
|||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "gemma/benchmark_helper.h"
|
||||
#include "gemma/common.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
|
|
@ -44,13 +46,55 @@ class GemmaTest : public ::testing::Test {
|
|||
return response;
|
||||
}
|
||||
|
||||
void TestQuestions(const char* kQA[][2], size_t num_questions) {
|
||||
std::vector<std::string> BatchGemmaReply(
|
||||
const std::vector<std::string>& inputs) {
|
||||
s_env->SetMaxGeneratedTokens(64);
|
||||
s_env->MutableConfig().temperature = 0.0f; // deterministic
|
||||
s_env->MutableConfig().verbosity = 0;
|
||||
// Using the turn structure worsens results.
|
||||
std::vector<std::unique_ptr<std::vector<int>>> prompts;
|
||||
prompts.reserve(inputs.size());
|
||||
for (auto input_string : inputs) {
|
||||
std::string mutable_input_string = input_string;
|
||||
prompts.push_back(std::make_unique<std::vector<int>>(
|
||||
s_env->TokenizeAndPrependBOS(input_string)));
|
||||
}
|
||||
std::vector<hwy::Span<int>> prompt_vector;
|
||||
for (auto& prompt : prompts) {
|
||||
prompt_vector.push_back(hwy::Span<int>(
|
||||
prompt->data(), prompt->size()));
|
||||
}
|
||||
hwy::Span<const hwy::Span<int>> prompt_span =
|
||||
hwy::Span<const hwy::Span<int>>(
|
||||
prompt_vector.data(), prompt_vector.size());
|
||||
std::vector<std::string> replies;
|
||||
for (auto [response, n] : s_env->BatchQueryModel2(prompt_span)) {
|
||||
replies.push_back(response);
|
||||
}
|
||||
return replies;
|
||||
}
|
||||
|
||||
void TestQuestions(const char* kQA[][2], size_t num_questions, bool batch) {
|
||||
if (!s_env->GetModel()) return;
|
||||
for (size_t i = 0; i < num_questions; ++i) {
|
||||
fprintf(stderr, "Question %zu\n\n", i + 1);
|
||||
std::string response = GemmaReply(kQA[i][0]);
|
||||
fprintf(stderr, "'%s'\n\n", response.c_str());
|
||||
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT
|
||||
if (batch) {
|
||||
std::vector<std::string> inputs;
|
||||
for (size_t i = 0; i < num_questions; ++i) {
|
||||
fprintf(stderr, "Batch Question %zu\n\n", i + 1);
|
||||
inputs.push_back(kQA[i][0]);
|
||||
}
|
||||
std::vector<std::string> responses = BatchGemmaReply(inputs);
|
||||
for (size_t i = 0; i < num_questions; ++i) {
|
||||
std::string response = responses.at(i);
|
||||
fprintf(stderr, "Batch answer %zu '%s'\n\n", i + 1, response.c_str());
|
||||
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < num_questions; ++i) {
|
||||
fprintf(stderr, "Question %zu\n\n", i + 1);
|
||||
std::string response = GemmaReply(kQA[i][0]);
|
||||
fprintf(stderr, "'%s'\n\n", response.c_str());
|
||||
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
@ -58,10 +102,16 @@ class GemmaTest : public ::testing::Test {
|
|||
TEST_F(GemmaTest, Geography) {
|
||||
static const char* kQA[][2] = {
|
||||
{"What is the capital of Hungary?", "Budapest"},
|
||||
{"What is the capital of Australia?", "Canberra"},
|
||||
{"How many states does the US have?", "50"},
|
||||
};
|
||||
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
|
||||
TestQuestions(kQA, kNum);
|
||||
TestQuestions(kQA, kNum, /* batch= */ false);
|
||||
static const char* kQA_single_question[][2] = {
|
||||
{"What is the capital of Australia?", "Canberra"},
|
||||
};
|
||||
TestQuestions(kQA_single_question, 1, /* batch= */ true);
|
||||
TestQuestions(kQA, kNum, /* batch= */ true);
|
||||
}
|
||||
|
||||
TEST_F(GemmaTest, History) {
|
||||
|
|
@ -69,7 +119,7 @@ TEST_F(GemmaTest, History) {
|
|||
{"When was the battle of Hastings?", "1066"},
|
||||
};
|
||||
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
|
||||
TestQuestions(kQA, kNum);
|
||||
TestQuestions(kQA, kNum, /* batch= */ false);
|
||||
}
|
||||
|
||||
TEST_F(GemmaTest, Arithmetic) {
|
||||
|
|
@ -78,7 +128,7 @@ TEST_F(GemmaTest, Arithmetic) {
|
|||
{"what is 7 * 8?", "56"},
|
||||
};
|
||||
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
|
||||
TestQuestions(kQA, kNum);
|
||||
TestQuestions(kQA, kNum, /* batch= */ false);
|
||||
}
|
||||
|
||||
static const char kJingleBells[] = R"(
|
||||
|
|
|
|||
Loading…
Reference in New Issue