De-templatize Activations, add RowVectorBatch class

Also remove most kBatchSize args.

PiperOrigin-RevId: 653185525
Jan Wassenberg 2024-07-17 04:37:40 -07:00 committed by Copybara-Service
parent ff34370aac
commit 992a2cbbc0
5 changed files with 305 additions and 293 deletions
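
For orientation, a minimal usage sketch of the new RowVectorBatch class follows (an editor's example, not part of the commit; it uses only the members shown in the activations.h diff below, and the sizes 4 and 2048 are arbitrary):

// A (batch_size x len) matrix of aligned floats, sized at runtime instead of
// through a kBatchSize template argument.
gcpp::RowVectorBatch<float> x(/*batch_size=*/4, /*len=*/2048);
for (size_t i = 0; i < x.BatchSize(); ++i) {
  float* row = x.Batch(i);                   // row vector of length x.Len()
  for (size_t j = 0; j < x.Len(); ++j) row[j] = 0.0f;
}
float* all = x.All();  // contiguous rows, e.g. for MatMul_4x4_Batch
x = gcpp::RowVectorBatch<float>(8, 2048);    // move-only; old storage is freed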


@@ -18,77 +18,130 @@
#include <stddef.h>
#include <array>
#include "gemma/common.h" // AllocateSizeof
#include "hwy/base.h" // hwy::bfloat16_t
#include "gemma/common.h" // kMaxThreads - TODO: remove
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // HWY_DASSERT
namespace gcpp {
// Must be aligned.
template <class TConfig, size_t kBatchSize>
// Owns dynamically-allocated aligned memory for a batch of row vectors.
// This can be seen as a (batch_size x len) matrix.
template <typename T>
class RowVectorBatch {
public:
// Default ctor for Activations ctor.
RowVectorBatch() : batch_size_(0), len_(0) {}
// Main ctor, called from Activations::Allocate.
RowVectorBatch(size_t batch_size, size_t len)
: batch_size_(batch_size), len_(len) {
mem_ = hwy::AllocateAligned<T>(batch_size * len);
}
// Move-only
RowVectorBatch(RowVectorBatch&) noexcept = delete;
RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
RowVectorBatch(RowVectorBatch&&) noexcept = default;
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
size_t BatchSize() const { return batch_size_; }
size_t Len() const { return len_; }
// Returns the given row vector of length `Len()`.
T* Batch(size_t batch_idx) {
HWY_DASSERT(batch_idx < batch_size_);
return mem_.get() + batch_idx * len_;
}
// For MatMul or other operations that process the entire batch at once.
T* All() { return mem_.get(); }
size_t NumBytes() const { return batch_size_ * len_ * sizeof(T); }
private:
hwy::AlignedFreeUniquePtr<T[]> mem_;
size_t batch_size_; // rows in the matrix
size_t len_; // columns in the matrix = vector length
};
struct Activations {
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads;
static constexpr bool kIsMHA = kHeads == kKVHeads; // Multi-Head Attention
// Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
static constexpr size_t kQStride = kQKVDim * (kIsMHA ? 3 : 1);
RowVectorBatch<float> x; // input
RowVectorBatch<float> q; // query, also KV if MHA.
RowVectorBatch<float> logits;
std::array<float, kBatchSize * kModelDim> x; // input
std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
std::array<float, kBatchSize * kHeads * kQStride> q; // query vector
std::array<float, kBatchSize * kHeads * TConfig::kSeqLen>
att; // attention vector
std::array<float, kBatchSize * kHeads * kQKVDim> att_out; // attention output
std::array<float, kHeads * kBatchSize * kModelDim>
att_post1; // attention output after linear transformation, per head
std::array<float, kBatchSize * kModelDim>
att_post2; // accumulation of attention outputs over heads
std::array<hwy::bfloat16_t, kBatchSize * kModelDim> bf_pre_ffw_rms_out;
std::array<float, kBatchSize * TConfig::kFFHiddenDim * 2> ffw_hidden;
// Attention
RowVectorBatch<float> pre_att_rms_out;
RowVectorBatch<float> att; // attention vector
RowVectorBatch<float> att_out; // attention output
// After linear transformation, shared by all heads
RowVectorBatch<float> att_post1;
// Accumulation of attention outputs over heads
RowVectorBatch<float> att_post2;
// For FFW MatMul.
std::array<float, kBatchSize * TConfig::kFFHiddenDim> C1;
std::array<float, kBatchSize * TConfig::kFFHiddenDim> C2;
// Gated FFW
RowVectorBatch<hwy::bfloat16_t> bf_pre_ffw_rms_out;
RowVectorBatch<float> C1;
RowVectorBatch<float> C2;
RowVectorBatch<float> ffw_out;
std::array<float, kBatchSize * kModelDim> ffw_out;
std::array<float, kBatchSize * TConfig::kVocabSize> logits;
// Griffin
RowVectorBatch<float> griffin_x;
RowVectorBatch<float> griffin_y;
RowVectorBatch<float> griffin_gate_x;
RowVectorBatch<float> griffin_multiplier;
// For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into
// per-thread storage.
// TODO: only used for MatVec, remove once that is gone.
std::array<float, kModelDim * kMaxThreads> even_odd;
// TODO: remove once MatVec is gone.
RowVectorBatch<float> even_odd;
// Griffin layer internal activations
static constexpr size_t kGriffinDim =
TConfig::kGriffinLayers > 0 ? kModelDim : 0;
std::array<float, kBatchSize * kGriffinDim> griffin_x;
std::array<float, kBatchSize * kGriffinDim> griffin_y;
std::array<float, kBatchSize * kGriffinDim> griffin_gate_x;
std::array<float, kBatchSize * kGriffinDim> griffin_multiplier;
};
// Multi-Head Attention?
template <class TConfig>
static constexpr bool IsMHA() {
return TConfig::kHeads == TConfig::kKVHeads;
}
template <typename TConfig>
struct AllocateState {
void operator()(ByteStorageT& prefill, ByteStorageT& decode) const {
// When batching queries, the prefill batch size is reduced by a factor
// of kBatchedQueryBatchSize
prefill =
AllocateSizeof<Activations<TConfig, kMinAdjustedPrefillBatchSize *
kBatchedQueryBatchSize>>();
decode = AllocateSizeof<
Activations<TConfig, kDecodeBatchSize * kBatchedQueryBatchSize>>();
// Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
template <class TConfig>
static constexpr size_t QStride() {
return TConfig::kQKVDim * (IsMHA<TConfig>() ? 3 : 1);
}
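// Editor's note (not part of this commit): a numeric example of the stride,
// assuming a hypothetical config with kQKVDim = 256.
// - MHA (kHeads == kKVHeads): QStride() = 3 * 256 = 768; each head's slice of
//   q holds Q, K, V back to back, and K/V are later copied into the KV cache
//   (see the kIsMHA branch in Attention further down).
// - Otherwise: QStride() = 256; q holds only Q, and K/V are computed
//   separately into the KV cache via MatVec.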
template <class TConfig>
void Allocate(size_t batch_size) {
constexpr size_t kModelDim = TConfig::kModelDim;
constexpr size_t kQKVDim = TConfig::kQKVDim;
constexpr size_t kHeads = TConfig::kHeads;
constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
constexpr size_t kVocabSize = TConfig::kVocabSize;
constexpr size_t kSeqLen = TConfig::kSeqLen;
constexpr size_t kGriffinLayers = TConfig::kGriffinLayers;
x = RowVectorBatch<float>(batch_size, kModelDim);
q = RowVectorBatch<float>(batch_size, kHeads * QStride<TConfig>());
logits = RowVectorBatch<float>(batch_size, kVocabSize);
pre_att_rms_out = RowVectorBatch<float>(batch_size, kModelDim);
att = RowVectorBatch<float>(batch_size, kHeads * kSeqLen);
att_out = RowVectorBatch<float>(batch_size, kHeads * kQKVDim);
att_post1 = RowVectorBatch<float>(1, kModelDim);
att_post2 = RowVectorBatch<float>(batch_size, kModelDim);
bf_pre_ffw_rms_out = RowVectorBatch<hwy::bfloat16_t>(batch_size, kModelDim);
C1 = RowVectorBatch<float>(batch_size, kFFHiddenDim);
C2 = RowVectorBatch<float>(batch_size, kFFHiddenDim);
ffw_out = RowVectorBatch<float>(batch_size, kModelDim);
if (kGriffinLayers > 0) {
griffin_x = RowVectorBatch<float>(batch_size, kModelDim);
griffin_y = RowVectorBatch<float>(batch_size, kModelDim);
griffin_gate_x = RowVectorBatch<float>(batch_size, kModelDim);
griffin_multiplier = RowVectorBatch<float>(batch_size, kModelDim);
}
even_odd = RowVectorBatch<float>(1, kModelDim * kMaxThreads);
}
};
template <class TConfig, size_t kBatchSize>
Activations<TConfig, kBatchSize>& GetActivations(const ByteStorageT& state_u8) {
return *reinterpret_cast<Activations<TConfig, kBatchSize>*>(state_u8.get());
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
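
The call sites in gemma.cc (later in this diff) size these buffers once at model construction. A condensed sketch of that flow, assuming some concrete TConfig from configs.h (editor's example, not part of the commit):

gcpp::Activations prefill, decode;  // members start empty (batch_size 0)
prefill.Allocate<TConfig>(kMinAdjustedPrefillBatchSize * kBatchedQueryBatchSize);
decode.Allocate<TConfig>(kDecodeBatchSize * kBatchedQueryBatchSize);
// Per-row access during decode, e.g. the first query's logits:
float* logits0 = decode.logits.Batch(0);  // row of TConfig::kVocabSize floats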


@@ -30,6 +30,7 @@
#include <algorithm>
#include <string>
#include <type_traits>
#include <vector>
#include "gemma/activations.h"
@@ -59,33 +60,27 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
template <class TConfig>
HWY_NOINLINE void GriffinRecurrent(
size_t batch_start, size_t num_tokens, size_t num_queries, size_t layer,
Activations<TConfig, kBatchSize * kQueryBatchSize>& activations,
const CompressedLayer<TConfig>* layer_weights,
Activations& activations, const CompressedLayer<TConfig>* layer_weights,
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool) {
PROFILER_ZONE("Gen.Griffin");
static_assert(kQueryBatchSize == 1,
"Griffin does not support batched queries.");
HWY_ASSERT(num_queries == 1); // TODO: add batch query support for Griffin.
KVCache& kv_cache = *kv_caches[0];
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
HWY_ASSERT(num_tokens <= kBatchSize);
static constexpr size_t kModelDim =
gcpp::Activations<TConfig, kBatchSize * kQueryBatchSize>::kModelDim;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
static constexpr size_t kHeads = TConfig::kHeads;
// X / Y linear layers.
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t batch_offset = batch_idx * kModelDim;
float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
float* HWY_RESTRICT y = activations.griffin_y.Batch(batch_idx);
float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
TwoMatVecAdd<kModelDim, kModelDim>(
layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0,
activations.pre_att_rms_out.data() + batch_offset,
activations.pre_att_rms_out.Batch(batch_idx),
/*add0=*/layer_weights->griffin.linear_x_biases.data(),
/*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x,
/*out1=*/y, pool);
@@ -94,9 +89,8 @@ HWY_NOINLINE void GriffinRecurrent(
// Conv1D.
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t batch_offset = batch_idx * kModelDim;
const size_t pos = batch_start + batch_idx;
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
HWY_FULL(float) df;
HWY_DASSERT(kModelDim % hn::Lanes(df) == 0);
const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1);
@@ -130,14 +124,11 @@ HWY_NOINLINE void GriffinRecurrent(
// RGLRU
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t batch_offset = batch_idx * kModelDim;
const size_t pos = batch_start + batch_idx;
float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
float* HWY_RESTRICT gate_x =
activations.griffin_gate_x.data() + batch_offset;
float* HWY_RESTRICT a =
activations.griffin_multiplier.data() + batch_offset;
float* HWY_RESTRICT y = activations.griffin_y.Batch(batch_idx);
float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Batch(batch_idx);
float* HWY_RESTRICT a = activations.griffin_multiplier.Batch(batch_idx);
float* HWY_RESTRICT rnn_state =
kv_cache.rglru_cache.get() + layer * kModelDim;
@@ -185,13 +176,12 @@ HWY_NOINLINE void GriffinRecurrent(
// Final linear layer.
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t batch_offset = batch_idx * kModelDim;
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
float* out_ptr = activations.att_post2.Batch(batch_idx);
MatVecAdd<kModelDim, kModelDim>(
layer_weights->griffin.linear_out_w, 0, x,
layer_weights->griffin.linear_out_biases.data(),
activations.even_odd.data(), out_ptr, pool);
activations.even_odd.All(), out_ptr, pool);
}
}
@@ -202,30 +192,26 @@ HWY_NOINLINE void PostQK(T* HWY_RESTRICT t, size_t pos, size_t layer) {
Rope(t, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
}
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
HWY_NOINLINE void Attention(
size_t batch_and_query_start, size_t num_tokens, size_t num_queries,
size_t layer,
Activations<TConfig, kBatchSize * kQueryBatchSize>& activations,
const CompressedLayer<TConfig>* layer_weights,
const std::vector<KVCache*>& kv_caches,
hwy::ThreadPool& pool) {
template <class TConfig>
HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
size_t num_queries, size_t layer,
Activations& activations,
const CompressedLayer<TConfig>* layer_weights,
const std::vector<KVCache*>& kv_caches,
hwy::ThreadPool& pool) {
PROFILER_ZONE("Gen.Attention");
HWY_DASSERT(num_tokens <= kBatchSize);
HWY_DASSERT(num_queries <= kQueryBatchSize);
HWY_DASSERT(batch_and_query_start % num_queries == 0);
using TActivations = Activations<TConfig, kBatchSize * kQueryBatchSize>;
constexpr size_t kQKVDim = TActivations::kQKVDim;
constexpr size_t kQStride = TActivations::kQStride;
constexpr size_t kQKVDim = TConfig::kQKVDim;
constexpr size_t kQStride = Activations::QStride<TConfig>();
constexpr size_t kCachePosSize = CachePosSize<TConfig>()();
constexpr size_t kCacheLayerSize = CacheLayerSize<TConfig>()();
constexpr size_t kModelDim = TActivations::kModelDim;
constexpr size_t kModelDim = TConfig::kModelDim;
constexpr size_t kHeads = TConfig::kHeads;
constexpr size_t kKVHeads = TConfig::kKVHeads;
constexpr size_t kSeqLen = TConfig::kSeqLen;
GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale<TConfig>();
// Multi-Head Attention a.k.a. "use_qkv_einsum".
constexpr bool kIsMHA = TActivations::kIsMHA;
constexpr bool kIsMHA = Activations::IsMHA<TConfig>();
static_assert(!kIsMHA || TConfig::kInterleaveQKV); // MHA => interleaved
const size_t batch_start = batch_and_query_start / num_queries;
const size_t num_tokens_and_queries = num_tokens * num_queries;
@@ -237,15 +223,14 @@ HWY_NOINLINE void Attention(
// Compute Q only or QKV (if MHA).
// If MHA, this also computes KV, which we copy to the KV cache below.
MatMul_4x4_Batch<kModelDim, kHeads * kQStride>(
num_tokens_and_queries, activations.pre_att_rms_out.data(),
layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);
num_tokens_and_queries, activations.pre_att_rms_out.All(),
layer_weights->qkv_einsum_w.data(), activations.q.All(), pool);
// Compute KV if not MHA.
if constexpr (!kIsMHA) {
for (size_t batch_and_query_idx = 0;
batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
const float* x =
activations.pre_att_rms_out.data() + batch_and_query_idx * kModelDim;
const float* x = activations.pre_att_rms_out.Batch(batch_and_query_idx);
const size_t query_idx = batch_and_query_idx % num_queries;
const size_t batch_idx = batch_and_query_idx / num_queries;
KVCache& kv_cache = *kv_caches[query_idx];
@@ -258,7 +243,7 @@ HWY_NOINLINE void Attention(
// TODO: requires MatMul support for offsets.
MatVec<kKVHeads * 2 * kQKVDim, kModelDim>(
layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
activations.even_odd.data(), kv, pool);
activations.even_odd.All(), kv, pool);
}
}
@@ -279,8 +264,7 @@ HWY_NOINLINE void Attention(
if constexpr (kIsMHA) {
// For MHA, copy KV into the KV cache from scratch space (see above).
const float* HWY_RESTRICT q =
activations.q.data() + (batch_and_query_idx * kHeads
+ head) * kQStride;
activations.q.Batch(batch_and_query_idx) + head * kQStride;
// Skip past the Q part of `q`, and copy KV to `kv`.
memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
}
@@ -300,7 +284,7 @@ HWY_NOINLINE void Attention(
const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2;
KVCache& kv_cache = *kv_caches[query_idx];
float* HWY_RESTRICT q =
activations.q.data() + (batch_and_query_idx * kHeads + head) * kQStride;
activations.q.Batch(batch_and_query_idx) + head * kQStride;
// Apply rope and scaling to Q.
const size_t pos = batch_start + batch_idx;
@@ -309,11 +293,7 @@ HWY_NOINLINE void Attention(
// Compute Q.K scores, yielding "logits" (or scores) in head_att.
float* HWY_RESTRICT head_att =
activations.att.data() + head * kSeqLen
+ batch_and_query_idx * kHeads * kSeqLen;
// Compute Q dot K scores
activations.att.Batch(batch_and_query_idx) + head * kSeqLen;
const size_t start_pos =
pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
@@ -335,8 +315,8 @@ HWY_NOINLINE void Attention(
// Summation of v (kv_cache) weighted by probs (head_att)
// into "encoded" (att_out). Compare gemma/modules.py:
// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
batch_and_query_idx * kHeads * kQKVDim;
float* HWY_RESTRICT att_out =
activations.att_out.Batch(batch_and_query_idx) + head * kQKVDim;
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
@@ -355,26 +335,24 @@ HWY_NOINLINE void Attention(
// TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
// rearranging the weights.
float* HWY_RESTRICT att_out =
activations.att_out.data() + batch_and_query_idx * kHeads * kQKVDim;
activations.att_out.Batch(batch_and_query_idx);
float* HWY_RESTRICT layer_out =
activations.att_post2.data() + batch_and_query_idx * kModelDim;
activations.att_post2.Batch(batch_and_query_idx);
// Head 0 (and potentially biases) -> layer_out.
// attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
MatVecT</*kAdd=*/TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
layer_weights->attn_vec_einsum_w, 0, att_out,
layer_weights->attention_output_biases.data(),
activations.even_odd.data(), layer_out, pool);
activations.even_odd.All(), layer_out, pool);
// Head 1 and following are added to layer_out.
for (size_t head = 1; head < kHeads; ++head) {
// TODO(patrickms): Check this calculation
float* HWY_RESTRICT head_out =
activations.att_post1.data() +
head * kBatchSize * kQueryBatchSize * kModelDim;
// NOTE: this is a single kModelDim temp output. If parallelized or using
// MatMul, add per-thread storage.
float* HWY_RESTRICT head_out = activations.att_post1.All();
// TODO: requires MatMul support for offsets.
MatVec<kModelDim, kQKVDim>(
layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim,
att_out + head * kQKVDim,
activations.even_odd.data(), head_out, pool);
att_out + head * kQKVDim, activations.even_odd.All(), head_out, pool);
AddFrom(head_out, layer_out, kModelDim);
}
}
@@ -392,13 +370,11 @@ HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2,
});
}
template <class TConfig, size_t kBatchSize>
HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
size_t num_tokens,
template <class TConfig>
HWY_NOINLINE void FFW(Activations& activations, size_t num_tokens,
const CompressedLayer<TConfig>* layer_weights,
hwy::ThreadPool& pool) {
PROFILER_ZONE("Gen.FFW");
HWY_DASSERT(num_tokens <= kBatchSize);
constexpr size_t kModelDim = TConfig::kModelDim;
constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
@@ -406,11 +382,12 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
// elements in memory, repeated kFFHiddenDim times.
constexpr size_t kColsA = kModelDim;
constexpr size_t kColsB = kFFHiddenDim;
const auto A = activations.bf_pre_ffw_rms_out.data();
HWY_DASSERT(num_tokens <= activations.bf_pre_ffw_rms_out.BatchSize());
const auto A = activations.bf_pre_ffw_rms_out.All();
const auto B1 = layer_weights->gating_einsum_w.data();
const auto B2 = B1 + kColsA * kColsB;
auto C1 = activations.C1.data();
auto C2 = activations.C2.data();
auto C1 = activations.C1.All();
auto C2 = activations.C2.All();
constexpr bool kAddBias = TConfig::kFFBiases;
const auto bias = layer_weights->ffw_gating_biases.data();
@@ -422,8 +399,7 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
bias + kFFHiddenDim, pool);
// Activation (Gelu) and multiply by gate. Store activations in C1.
Activation<TConfig>(activations.C1.data(), activations.C2.data(),
kFFHiddenDim * num_tokens);
Activation<TConfig>(C1, C2, kFFHiddenDim * num_tokens);
// linear_w may have a scale value different from 1, apply that here.
// We multiply all activations by the scale value to compensate for the
@@ -434,123 +410,115 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
// Hidden layer -> output layer.
MatMul_4x4_Batch_Add<kFFHiddenDim, kModelDim, kAddBias>(
num_tokens, C1, layer_weights->linear_w.data(),
activations.ffw_out.data(), layer_weights->ffw_output_biases.data(),
pool);
num_tokens, C1, layer_weights->linear_w.data(), activations.ffw_out.All(),
layer_weights->ffw_output_biases.data(), pool);
}
template <class TConfig, size_t kBatchSize>
HWY_NOINLINE void EmbedToken(int token, size_t token_idx, size_t pos,
// TODO: pass Activations.x instead of Activations.
template <class TConfig>
HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
const CompressedWeights<TConfig>& weights,
Activations<TConfig, kBatchSize>& activations) {
Activations& activations) {
constexpr size_t kModelDim = TConfig::kModelDim;
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
EmbeddingScaling<TConfig>();
HWY_DASSERT(token >= 0);
HWY_DASSERT(token < TConfig::kVocabSize);
Decompress(weights.embedder_input_embedding, token * kModelDim,
activations.x.data() + token_idx * kModelDim, kModelDim);
MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
kModelDim);
activations.x.Batch(batch_idx), kModelDim);
MulByConst(kEmbScaling, activations.x.Batch(batch_idx), kModelDim);
if constexpr (TConfig::kAbsolutePE) {
AddAbsolutePositionalEmbeddings(
activations.x.data() + token_idx * kModelDim, kModelDim,
pos + token_idx);
AddAbsolutePositionalEmbeddings(activations.x.Batch(batch_idx), kModelDim,
pos + batch_idx);
};
}
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize, typename T>
template <class TConfig, typename T>
HWY_NOINLINE void ResidualConnection(
size_t num_tokens_and_queries, T* HWY_RESTRICT other, T* HWY_RESTRICT x,
const CompressedLayer<TConfig>* layer_weights, bool is_attention) {
constexpr size_t kModelDim = TConfig::kModelDim;
// ResidualType::Add
AddFromBatched<kBatchSize * kQueryBatchSize>(num_tokens_and_queries, other, x,
kModelDim);
AddFromBatched(num_tokens_and_queries, other, x, kModelDim);
}
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
template <class TConfig, size_t kQueryBatchSize>
HWY_NOINLINE void TransformerLayer(
size_t num_tokens, size_t num_queries, size_t pos, size_t layer,
const CompressedLayer<TConfig>* layer_weights,
Activations<TConfig, kBatchSize * kQueryBatchSize>& activations,
const CompressedLayer<TConfig>* layer_weights, Activations& activations,
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool) {
constexpr size_t kModelDim = TConfig::kModelDim;
const size_t num_tokens_and_queries = num_tokens * num_queries;
auto type = TConfig::kLayerConfig[layer];
size_t layer_of_type =
NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
RMSNormBatched<kBatchSize * kQueryBatchSize>(
num_tokens_and_queries, activations.x.data(),
layer_weights->pre_attention_norm_scale.data(),
activations.pre_att_rms_out.data(), kModelDim);
RMSNormBatched(num_tokens_and_queries, activations.x.All(),
layer_weights->pre_attention_norm_scale.data(),
activations.pre_att_rms_out.All(), kModelDim);
if (type == LayerAttentionType::kGemma) {
Attention<TConfig, kBatchSize, kQueryBatchSize>(
pos, num_tokens, num_queries, layer_of_type, activations,
layer_weights, kv_caches, pool);
Attention<TConfig>(pos, num_tokens, num_queries, layer_of_type, activations,
layer_weights, kv_caches, pool);
} else {
// This Griffin layers should never exist unless the model is a Griffin
// model. This conditional prevents the compiler from generating code for
// this branch when building a non-Griffin model, since we have static
// asserts about the query batch size for Griffin layers.
if constexpr (TConfig::kGriffinLayers > 0) {
GriffinRecurrent<TConfig, kBatchSize, kQueryBatchSize>(
pos, num_tokens, num_queries, layer_of_type, activations,
layer_weights, kv_caches, pool);
static_assert(kQueryBatchSize == 1,
"Griffin does not support batched queries.");
GriffinRecurrent<TConfig>(pos, num_tokens, num_queries, layer_of_type,
activations, layer_weights, kv_caches, pool);
}
}
if (TConfig::kPostNorm == PostNormType::Scale) {
RMSNormInplaceBatched<kBatchSize * kQueryBatchSize>(
num_tokens_and_queries,
layer_weights->post_attention_norm_scale.data(),
activations.att_post2.data(), kModelDim);
RMSNormInplaceBatched(num_tokens_and_queries,
layer_weights->post_attention_norm_scale.data(),
activations.att_post2.All(), kModelDim);
}
ResidualConnection<TConfig, kBatchSize, kQueryBatchSize>(
num_tokens_and_queries, activations.att_post2.data(),
activations.x.data(), layer_weights, /*is_attention*/ true);
RMSNormBatched<kBatchSize * kQueryBatchSize>(
num_tokens_and_queries, activations.x.data(),
layer_weights->pre_ffw_norm_scale.data(),
activations.bf_pre_ffw_rms_out.data(), kModelDim);
FFW<TConfig, kBatchSize * kQueryBatchSize>(
activations, num_tokens_and_queries, layer_weights, pool);
ResidualConnection<TConfig>(num_tokens_and_queries,
activations.att_post2.All(), activations.x.All(),
layer_weights, /*is_attention=*/true);
RMSNormBatched(num_tokens_and_queries, activations.x.All(),
layer_weights->pre_ffw_norm_scale.data(),
activations.bf_pre_ffw_rms_out.All(), kModelDim);
FFW<TConfig>(activations, num_tokens_and_queries, layer_weights, pool);
if (TConfig::kPostNorm == PostNormType::Scale) {
RMSNormInplaceBatched<kBatchSize * kQueryBatchSize>(
num_tokens_and_queries, layer_weights->post_ffw_norm_scale.data(),
activations.ffw_out.data(), kModelDim);
RMSNormInplaceBatched(num_tokens_and_queries,
layer_weights->post_ffw_norm_scale.data(),
activations.ffw_out.All(), kModelDim);
}
ResidualConnection<TConfig, kBatchSize, kQueryBatchSize>(
num_tokens_and_queries, activations.ffw_out.data(), activations.x.data(),
layer_weights, /*is_attention*/ false);
ResidualConnection<TConfig>(num_tokens_and_queries, activations.ffw_out.All(),
activations.x.All(), layer_weights,
/*is_attention=*/false);
}
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
HWY_NOINLINE void Prefill(
const int* tokens, size_t num_tokens, size_t num_queries, size_t pos,
const CompressedWeights<TConfig>& weights,
Activations<TConfig, kBatchSize * kQueryBatchSize>& activations,
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool) {
HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens,
size_t num_queries, size_t pos,
const CompressedWeights<TConfig>& weights,
Activations& activations,
const std::vector<KVCache*>& kv_caches,
hwy::ThreadPool& pool) {
PROFILER_ZONE("Gen.Prefill");
HWY_DASSERT(num_queries <= kQueryBatchSize);
const size_t minibatch_size = std::min(num_tokens, kBatchSize);
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
// TODO(patrickms): Try to hoist pool.Run out of the loop.
// TODO: hoist pool.Run out of the loop, change the unit of work to batches.
for (size_t i = 0; i < num_tokens; i += minibatch_size) {
const size_t offset = i * num_queries;
const size_t current_token_count = std::min(
minibatch_size, num_tokens - i);
pool.Run(0, current_token_count * num_queries,
[&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
EmbedToken<TConfig, kBatchSize * kQueryBatchSize>(
tokens[token_idx + offset], token_idx, pos + offset,
weights, activations);
});
[&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
EmbedToken<TConfig>(tokens[token_idx + offset], token_idx,
pos + offset, weights, activations);
});
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
const auto* layer_weights = weights.GetLayer(layer);
TransformerLayer<TConfig, kBatchSize, kQueryBatchSize>(
current_token_count, num_queries, pos + offset , layer, layer_weights,
TransformerLayer<TConfig, kQueryBatchSize>(
current_token_count, num_queries, pos + offset, layer, layer_weights,
activations, kv_caches, pool);
}
}
@@ -558,15 +526,14 @@ HWY_NOINLINE void Prefill(
// Compute the transformer for a batch of input tokens. During generation,
// we usually have num_tokens == 1 (and also kBatchSize == 1).
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
HWY_NOINLINE void Transformer(
const int* tokens, size_t num_tokens, size_t num_queries, size_t pos,
const CompressedWeights<TConfig>& weights,
Activations<TConfig, kBatchSize * kQueryBatchSize>& activations,
const std::vector<KVCache*>& kv_caches,
hwy::ThreadPool& pool,
const LayersOutputFunc& layers_output) {
HWY_ASSERT(num_tokens <= kBatchSize);
template <class TConfig, size_t kQueryBatchSize>
HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
size_t num_queries, size_t pos,
const CompressedWeights<TConfig>& weights,
Activations& activations,
const std::vector<KVCache*>& kv_caches,
hwy::ThreadPool& pool,
const LayersOutputFunc& layers_output) {
const size_t num_tokens_and_queries = num_tokens * num_queries;
if (layers_output) {
for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
@@ -577,34 +544,33 @@ HWY_NOINLINE void Transformer(
}
constexpr size_t kModelDim = TConfig::kModelDim;
for (size_t token_idx = 0; token_idx < num_tokens_and_queries; ++token_idx) {
EmbedToken<TConfig, kBatchSize * kQueryBatchSize>(
tokens[token_idx], token_idx, pos, weights, activations);
EmbedToken<TConfig>(tokens[token_idx], token_idx, pos, weights,
activations);
}
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
const CompressedLayer<TConfig>* layer_weights = weights.GetLayer(layer);
TransformerLayer<TConfig, kBatchSize, kQueryBatchSize>(
num_tokens, num_queries, pos, layer, layer_weights,
activations, kv_caches, pool);
TransformerLayer<TConfig, kQueryBatchSize>(num_tokens, num_queries, pos,
layer, layer_weights,
activations, kv_caches, pool);
if (layers_output) {
const std::string block_name = "blocks." + std::to_string(layer);
for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
++token_idx) {
layers_output(pos + token_idx, block_name,
activations.x.data() + token_idx * kModelDim, kModelDim);
activations.x.Batch(token_idx), kModelDim);
}
}
}
RMSNormInplaceBatched<kBatchSize * kQueryBatchSize>(
num_tokens * num_queries, weights.final_norm_scale.data(),
activations.x.data(), kModelDim);
RMSNormInplaceBatched(num_tokens_and_queries, weights.final_norm_scale.data(),
activations.x.All(), kModelDim);
if (layers_output) {
for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
++token_idx) {
layers_output(pos + token_idx, "final_norm",
activations.x.data() + token_idx * kModelDim, kModelDim);
activations.x.Batch(token_idx), kModelDim);
}
}
}
@@ -644,9 +610,8 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
// Placeholder for internal test3, do not remove
template <class TConfig, size_t kQueryBatchSize>
void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
const ByteStorageT& decode_u8,
const RuntimeConfig& runtime_config,
void GenerateT(const ByteStorageT& weights_u8, Activations& prefill,
Activations& activations, const RuntimeConfig& runtime_config,
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
const size_t query_index_offset,
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
@@ -659,10 +624,6 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
pos *= num_queries; // position in (num_queries) interleaved token sequence.
const CompressedWeights<TConfig>& weights =
*reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
auto& prefill_activations =
GetActivations<TConfig,
kAdjustedPrefillBatchSize * kQueryBatchSize>(prefill_u8);
auto& activations = GetActivations<TConfig, kQueryBatchSize>(decode_u8);
size_t min_prompt_size = (size_t)-1;
size_t max_prompt_size = 0;
@@ -735,8 +696,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
HWY_DASSERT(pos_offset + batch_size <= (min_prompt_size - 1) * num_queries);
const int* batch_tokens = prompt.data() + pos_offset;
Prefill<TConfig, kAdjustedPrefillBatchSize, kQueryBatchSize>(
batch_tokens, batch_size, num_queries, pos, weights,
prefill_activations, kv_caches, pool);
batch_tokens, batch_size, num_queries, pos, weights, prefill, kv_caches,
pool);
for (size_t idx = 0; idx < batch_size; ++idx) {
bool all_tokens_eos = true;
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
@@ -788,7 +749,7 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
for (size_t generate_pos = 0;
generate_pos < max_tokens && generate_pos < max_generated_tokens;
++single_prompt_pos_offset, ++generate_pos) {
Transformer<TConfig, kDecodeBatchSize, kQueryBatchSize>(
Transformer<TConfig, kQueryBatchSize>(
gen_tokens.data(), kDecodeBatchSize, num_queries, pos, weights,
activations, kv_caches, pool, runtime_config.layers_output);
float token_logit = 0.0f;
@@ -796,22 +757,20 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
// We keep it here for clarity so that the code is correct even if Prefill
// is disabled.
bool all_tokens_eos = true;
float* x = activations.x.data();
float* logits = activations.logits.data();
for (size_t i = 0; i < num_queries; ++i, ++pos, ++pos_offset,
x += TConfig::kModelDim, logits += kVocabSize) {
for (size_t i = 0; i < num_queries; ++i, ++pos, ++pos_offset) {
const float* HWY_RESTRICT x = activations.x.Batch(i);
float* HWY_RESTRICT logits = activations.logits.Batch(i);
const size_t prompt_size = prompts[i].size();
const bool is_generating_phase =
(single_prompt_pos_offset >= prompt_size - 1);
if (is_generating_phase) {
PROFILER_ZONE("Gen.Embedding");
// Compute logits from last layer activations.
MatVec<kVocabSize, TConfig::kModelDim>(
weights.embedder_input_embedding, 0, x, activations.even_odd.data(),
logits, pool);
MatVec<kVocabSize, TConfig::kModelDim>(weights.embedder_input_embedding,
0, x, activations.even_odd.All(),
logits, pool);
if constexpr (TConfig::kFinalCap > 0.0f) {
LogitsSoftCap(TConfig::kFinalCap, activations.logits.data(),
kVocabSize);
LogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize);
}
// Barrier: must have all logits so we can subtract max.
Softmax(logits, kVocabSize);
@@ -850,9 +809,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
}
template <class TConfig>
void GenerateSingleT(const ByteStorageT& weights_u8,
const ByteStorageT& prefill_u8,
const ByteStorageT& decode_u8,
void GenerateSingleT(const ByteStorageT& weights_u8, Activations& prefill,
Activations& activations,
const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
@@ -865,19 +823,17 @@ void GenerateSingleT(const ByteStorageT& weights_u8,
std::vector<KVCache*> kv_caches = {&kv_cache};
const size_t query_index_offset = 0;
GenerateT<TConfig, /*kQueryBatchSize=*/1>(
weights_u8, prefill_u8, decode_u8, runtime_config, prompts, pos,
weights_u8, prefill, activations, runtime_config, prompts, pos,
query_index_offset, kv_caches, pool, timing_info);
}
template <class TConfig>
void GenerateBatchT(const ByteStorageT& weights_u8,
const ByteStorageT& prefill_u8,
const ByteStorageT& decode_u8,
void GenerateBatchT(const ByteStorageT& weights_u8, Activations& prefill,
Activations& activations,
const RuntimeConfig& runtime_config,
const hwy::Span<const hwy::Span<int>>& prompts,
size_t pos, const std::vector<KVCache*>& kv_caches,
hwy::ThreadPool& pool,
TimingInfo& timing_info) {
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
const std::vector<KVCache*>& kv_caches,
hwy::ThreadPool& pool, TimingInfo& timing_info) {
// Disable query batching for Griffin models.
constexpr size_t kQueryBatchSize =
(TConfig::kGriffinLayers > 0) ? 1 : kBatchedQueryBatchSize;
@@ -885,9 +841,9 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
const size_t num_queries = std::min(prompts.size() - i, kQueryBatchSize);
const hwy::Span<const hwy::Span<int>> current_prompts(
prompts.data() + i, num_queries);
GenerateT<TConfig, kQueryBatchSize>(weights_u8, prefill_u8, decode_u8,
runtime_config, current_prompts,
pos, i, kv_caches, pool, timing_info);
GenerateT<TConfig, kQueryBatchSize>(weights_u8, prefill, activations,
runtime_config, current_prompts, pos, i,
kv_caches, pool, timing_info);
}
}
@@ -898,25 +854,23 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
// These are extern functions defined by instantiations/*.cc, which include this
// 'header' after defining GEMMA_CONFIG, which is for function overloading.
void GenerateSingle( // NOLINT(misc-definitions-in-headers)
GEMMA_CONFIG, const ByteStorageT& weights_u8,
const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8,
const RuntimeConfig& runtime_config, const std::vector<int>& prompt,
size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool,
TimingInfo& timing_info) {
GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& prefill,
Activations& activations, const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, TimingInfo& timing_info) {
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_CONFIG>)
(weights_u8, prefill_u8, decode_u8, runtime_config, prompt, pos, kv_cache,
(weights_u8, prefill, activations, runtime_config, prompt, pos, kv_cache,
pool, timing_info);
}
void GenerateBatch( // NOLINT(misc-definitions-in-headers)
GEMMA_CONFIG, const ByteStorageT& weights_u8,
const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8,
const RuntimeConfig& runtime_config,
GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& prefill,
Activations& activations, const RuntimeConfig& runtime_config,
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
TimingInfo& timing_info) {
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_CONFIG>)
(weights_u8, prefill_u8, decode_u8, runtime_config, prompts, pos, kv_caches,
(weights_u8, prefill, activations, runtime_config, prompts, pos, kv_caches,
pool, timing_info);
}


@@ -36,12 +36,23 @@
namespace gcpp {
template <typename TConfig>
struct AllocateState {
void operator()(Activations& prefill, Activations& decode) const {
// When batching queries, the prefill batch size is reduced by a factor
// of kBatchedQueryBatchSize
prefill.Allocate<TConfig>(kMinAdjustedPrefillBatchSize *
kBatchedQueryBatchSize);
decode.Allocate<TConfig>(kDecodeBatchSize * kBatchedQueryBatchSize);
}
};
Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
const ModelInfo& info, hwy::ThreadPool& pool)
: pool_(pool), tokenizer_(tokenizer_path), info_(info) {
weights_u8_ = LoadCompressedWeights(weights, info.model, info.weight, pool);
CallForModelAndWeight<AllocateState>(info.model, info.weight, prefill_u8_,
decode_u8_);
CallForModelAndWeight<AllocateState>(info.model, info.weight, prefill_,
decode_);
}
Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
@@ -50,8 +61,8 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
HWY_ASSERT(info.weight == Type::kF32);
weights_u8_ =
CallForModel<float, AllocateCompressedWeights>(info.model, pool);
CallForModelAndWeight<AllocateState>(info.model, info.weight, prefill_u8_,
decode_u8_);
CallForModelAndWeight<AllocateState>(info.model, info.weight, prefill_,
decode_);
}
Gemma::~Gemma() {
@@ -63,19 +74,17 @@ Gemma::~Gemma() {
// we shard them across multiple translation units in instantiations/*.cc.
// This declares the functions defined there. We use overloading because
// explicit instantiations are still too slow to compile.
#define GEMMA_DECLARE(CONFIGT, TWEIGHT) \
extern void GenerateSingle( \
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8, \
const RuntimeConfig& runtime_config, const std::vector<int>& prompt, \
size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, \
TimingInfo& timing_info); \
extern void GenerateBatch( \
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8, \
const RuntimeConfig& runtime_config, \
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos, \
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool, \
#define GEMMA_DECLARE(CONFIGT, TWEIGHT) \
extern void GenerateSingle( \
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, Activations& prefill, \
Activations& decode, const RuntimeConfig& runtime_config, \
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache, \
hwy::ThreadPool& pool, TimingInfo& timing_info); \
extern void GenerateBatch( \
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, Activations& prefill, \
Activations& decode, const RuntimeConfig& runtime_config, \
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos, \
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool, \
TimingInfo& timing_info);
GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE);
@@ -83,25 +92,23 @@ GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE);
// TODO: gather all ByteStorageT into a type-erased model struct?
template <class TConfig>
struct GenerateSingleT {
void operator()(const ByteStorageT& weights_u8,
const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8,
const RuntimeConfig& runtime_config,
void operator()(const ByteStorageT& weights_u8, Activations& prefill,
Activations& decode, const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, TimingInfo& timing_info) const {
GenerateSingle(TConfig(), weights_u8, prefill_u8, decode_u8, runtime_config,
GenerateSingle(TConfig(), weights_u8, prefill, decode, runtime_config,
prompt, pos, kv_cache, pool, timing_info);
}
};
template <class TConfig>
struct GenerateBatchT {
void operator()(const ByteStorageT& weights_u8,
const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8,
const RuntimeConfig& runtime_config,
void operator()(const ByteStorageT& weights_u8, Activations& prefill,
Activations& decode, const RuntimeConfig& runtime_config,
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
TimingInfo& timing_info) const {
GenerateBatch(TConfig(), weights_u8, prefill_u8, decode_u8, runtime_config,
GenerateBatch(TConfig(), weights_u8, prefill, decode, runtime_config,
prompts, pos, kv_caches, pool, timing_info);
}
};
@@ -112,8 +119,8 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
CallForModelAndWeight<GenerateSingleT>(
info_.model, info_.weight, weights_u8_, prefill_u8_, decode_u8_,
runtime_config, prompt, start_pos, kv_cache, pool_, timing_info);
info_.model, info_.weight, weights_u8_, prefill_, decode_, runtime_config,
prompt, start_pos, kv_cache, pool_, timing_info);
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
}
@@ -126,8 +133,8 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
CallForModelAndWeight<GenerateBatchT>(
info_.model, info_.weight, weights_u8_, prefill_u8_, decode_u8_,
runtime_config, prompts, start_pos, kv_caches, pool_, timing_info);
info_.model, info_.weight, weights_u8_, prefill_, decode_, runtime_config,
prompts, start_pos, kv_caches, pool_, timing_info);
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
}


@@ -23,6 +23,7 @@
// IWYU pragma: begin_exports
#include "compression/io.h" // Path
#include "gemma/activations.h"
#include "gemma/common.h"
#include "gemma/kv_cache.h"
#include "gemma/tokenizer.h"
@@ -95,8 +96,8 @@ class Gemma {
const ModelInfo& Info() const { return info_; }
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
const ByteStorageT& Weights() const { return weights_u8_; }
const ByteStorageT& Prefill() const { return prefill_u8_; }
const ByteStorageT& Decode() const { return decode_u8_; }
const Activations& Prefill() const { return prefill_; }
const Activations& Decode() const { return decode_; }
void Generate(const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t start_pos,
@@ -114,8 +115,8 @@ class Gemma {
// Type-erased so that this can be defined in the header, without requiring
// forwarding functions.
ByteStorageT weights_u8_;
ByteStorageT prefill_u8_;
ByteStorageT decode_u8_;
Activations prefill_;
Activations decode_;
ModelInfo info_;
};


@@ -1140,29 +1140,26 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
}
// Simple loops unless/until batch sizes are large enough to parallelize.
template <size_t kBatchSize, typename WeightT, typename OutT>
template <typename WeightT, typename OutT>
void RMSNormBatched(size_t num_tokens, const float* activations,
const WeightT* weights, OutT* out, const size_t model_dim) {
HWY_DASSERT(num_tokens <= kBatchSize);
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
RMSNorm(activations + token_idx * model_dim, weights,
out + token_idx * model_dim, model_dim);
}
}
template <size_t kBatchSize, typename WeightT, typename InOutT>
// TODO: pass RowVectorBatch argument.
template <typename WeightT, typename InOutT>
void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights,
InOutT* inout, const size_t model_dim) {
HWY_DASSERT(num_tokens <= kBatchSize);
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
RMSNormInplace(weights, inout + token_idx * model_dim, model_dim);
}
}
template <size_t kBatchSize>
void AddFromBatched(size_t num_tokens, const float* other, float* x,
const size_t model_dim) {
HWY_DASSERT(num_tokens <= kBatchSize);
static HWY_INLINE void AddFromBatched(size_t num_tokens, const float* other,
float* x, const size_t model_dim) {
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
AddFrom(other + token_idx * model_dim, x + token_idx * model_dim,
model_dim);