mirror of https://github.com/google/gemma.cpp.git
Cleanup: add wrapper functions and rename vars to interleaved
Simplifies the TransformerLayer function. Use interleaved* instead of _and_queries.

PiperOrigin-RevId: 653929449
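
The "interleaved" naming refers to the flattened batch that the per-layer functions walk: num_interleaved = num_tokens * num_queries rows, with the query index varying fastest, so query_idx = interleaved_idx % num_queries and batch_idx = interleaved_idx / num_queries (the pattern used throughout GemmaAttention in the diff below). A minimal sketch of that index mapping follows; the helper itself is illustrative and not part of the commit, only the index arithmetic mirrors the diff:

#include <cstddef>

// Illustrative only: recovers (batch_idx, query_idx) from an interleaved row
// index, mirroring the `% num_queries` / `/ num_queries` arithmetic in the
// diff. batch_idx selects the token position within the batch; query_idx
// selects which query (and therefore which KVCache) the row belongs to.
struct InterleavedIndex {
  size_t batch_idx;
  size_t query_idx;
};

inline InterleavedIndex Deinterleave(size_t interleaved_idx,
                                     size_t num_queries) {
  return InterleavedIndex{interleaved_idx / num_queries,
                          interleaved_idx % num_queries};
}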
This commit is contained in:
parent 12016d31c3
commit 5844e6a1e5
@@ -63,6 +63,11 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
 
+// Different functions use different naming conventions for the number of
+// tokens. Functions that are query-independent, such as RMSNorm*, call the
+// count `num_interleaved`. Functions that are query-dependent, such as
+// `Attention`, use separate `num_tokens` and `num_queries`.
+
 template <class TConfig>
 HWY_NOINLINE void GriffinRecurrent(
     size_t batch_start, size_t num_tokens, size_t num_queries, size_t layer,
@@ -191,21 +196,21 @@ HWY_NOINLINE void GriffinRecurrent(
 }
 
 template <class TConfig, typename T>
-HWY_NOINLINE void PostQK(T* HWY_RESTRICT t, size_t pos, size_t layer) {
+HWY_NOINLINE void PostQK(T* HWY_RESTRICT inout, size_t pos, size_t layer) {
   constexpr size_t kQKVDim = TConfig::kQKVDim;
   // PostQKType::Rope
-  Rope(t, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+  Rope(inout, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
 }
 
 template <class TConfig>
-HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
+HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
                                  size_t num_queries, size_t layer,
                                  Activations& activations,
                                  const CompressedLayer<TConfig>* layer_weights,
                                  const std::vector<KVCache*>& kv_caches,
                                  hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Attention");
-  HWY_DASSERT(batch_and_query_start % num_queries == 0);
+  HWY_DASSERT(interleaved_start % num_queries == 0);
   constexpr size_t kQKVDim = TConfig::kQKVDim;
   constexpr size_t kQStride = Activations::QStride<TConfig>();
   constexpr size_t kCachePosSize = CachePosSize<TConfig>()();
@@ -218,8 +223,8 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
   // Multi-Head Attention a.k.a. "use_qkv_einsum".
   constexpr bool kIsMHA = Activations::IsMHA<TConfig>();
   static_assert(!kIsMHA || TConfig::kInterleaveQKV);  // MHA => interleaved
-  const size_t batch_start = batch_and_query_start / num_queries;
-  const size_t num_tokens_and_queries = num_tokens * num_queries;
+  const size_t batch_start = interleaved_start / num_queries;
+  const size_t num_interleaved = num_tokens * num_queries;
 
   // For the computation of Q, K, and V, it is useful to remember that
   // qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kKQVDim, kModelDim]
@@ -229,16 +234,16 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
   // If MHA, this also computes KV, which we copy to the KV cache below.
   const float scale = layer_weights->qkv_einsum_w.scale();
   MatMul_4x4_Batch<kModelDim, kHeads * kQStride>(
-      num_tokens_and_queries, activations.pre_att_rms_out.All(),
+      num_interleaved, activations.pre_att_rms_out.All(),
       layer_weights->qkv_einsum_w.data(), scale, activations.q.All(), pool);
 
   // Compute KV if not MHA.
   if constexpr (!kIsMHA) {
-    for (size_t batch_and_query_idx = 0;
-         batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
-      const float* x = activations.pre_att_rms_out.Batch(batch_and_query_idx);
-      const size_t query_idx = batch_and_query_idx % num_queries;
-      const size_t batch_idx = batch_and_query_idx / num_queries;
+    for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
+         ++interleaved_idx) {
+      const float* x = activations.pre_att_rms_out.Batch(interleaved_idx);
+      const size_t query_idx = interleaved_idx % num_queries;
+      const size_t batch_idx = interleaved_idx / num_queries;
       KVCache& kv_cache = *kv_caches[query_idx];
       const size_t pos = batch_start + batch_idx;
       const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
@@ -255,12 +260,12 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
 
   // Apply positional encodings for K (and copy KV to cache if MHA).
   pool.Run(
-      0, kKVHeads * num_tokens_and_queries,
+      0, kKVHeads * num_interleaved,
       [&](uint64_t task, size_t thread) HWY_ATTR {
         const size_t head = task % kKVHeads;
-        const size_t batch_and_query_idx = task / kKVHeads;
-        const size_t query_idx = batch_and_query_idx % num_queries;
-        const size_t batch_idx = batch_and_query_idx / num_queries;
+        const size_t interleaved_idx = task / kKVHeads;
+        const size_t query_idx = interleaved_idx % num_queries;
+        const size_t batch_idx = interleaved_idx / num_queries;
         const size_t pos = batch_start + batch_idx;
         const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
         const size_t kv_offset = cache_pos * kCachePosSize +
@@ -270,7 +275,7 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
         if constexpr (kIsMHA) {
           // For MHA, copy KV into the KV cache from scratch space (see above).
           const float* HWY_RESTRICT q =
-              activations.q.Batch(batch_and_query_idx) + head * kQStride;
+              activations.q.Batch(interleaved_idx) + head * kQStride;
           // Skip past the Q part of `q`, and copy KV to `kv`.
           hwy::CopyBytes(q + kQKVDim, kv, 2 * kQKVDim * sizeof(float));
         }
@@ -281,16 +286,16 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
                 "query heads must be a multiple of key-value heads");
   constexpr size_t kGroupHeads = kHeads / kKVHeads;
   // For each head (token, query), compute Q.K, softmax, and weighted V.
-  pool.Run(0, kHeads * num_tokens_and_queries,
-           [&](uint64_t task, size_t thread) HWY_ATTR {
+  pool.Run(
+      0, kHeads * num_interleaved, [&](uint64_t task, size_t thread) HWY_ATTR {
         const size_t head = task % kHeads;
-        const size_t batch_and_query_idx = task / kHeads;
-        const size_t query_idx = batch_and_query_idx % num_queries;
-        const size_t batch_idx = batch_and_query_idx / num_queries;
+        const size_t interleaved_idx = task / kHeads;
+        const size_t query_idx = interleaved_idx % num_queries;
+        const size_t batch_idx = interleaved_idx / num_queries;
         const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2;
         KVCache& kv_cache = *kv_caches[query_idx];
         float* HWY_RESTRICT q =
-            activations.q.Batch(batch_and_query_idx) + head * kQStride;
+            activations.q.Batch(interleaved_idx) + head * kQStride;
 
         // Apply rope and scaling to Q.
         const size_t pos = batch_start + batch_idx;
@@ -299,7 +304,7 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
 
         // Compute Q.K scores, yielding "logits" (or scores) in head_att.
         float* HWY_RESTRICT head_att =
-            activations.att.Batch(batch_and_query_idx) + head * kSeqLen;
+            activations.att.Batch(interleaved_idx) + head * kSeqLen;
         const size_t start_pos =
             pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
         for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
@@ -311,7 +316,8 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
           head_att[pos2 % kSeqLen] = score;
         }
 
-        // SoftMax. May be preceded by SoftCap. Yields "probabilities" in head_att.
+        // SoftMax. May be preceded by SoftCap. Yields "probabilities" in
+        // head_att.
         const size_t head_att_len = std::min(pos + 1, kSeqLen);
         if constexpr (TConfig::kAttCap > 0.0f) {
           LogitsSoftCap(TConfig::kAttCap, head_att, head_att_len);
@@ -322,13 +328,14 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
         // into "encoded" (att_out). Compare gemma/modules.py:
         // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
         float* HWY_RESTRICT att_out =
-            activations.att_out.Batch(batch_and_query_idx) + head * kQKVDim;
+            activations.att_out.Batch(interleaved_idx) + head * kQKVDim;
         hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
         for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
           const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
           const size_t kv_offset =
               cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
-          float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + kv_offset + kQKVDim;
+          float* HWY_RESTRICT v2 =
+              kv_cache.kv_cache.get() + kv_offset + kQKVDim;
           MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim);
         }
       });
@@ -336,14 +343,13 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
   // Sum encoded (att_out) over num_heads and head_dim (kQKVDim)
   // into output (layer_out). Compare gemma/modules.py:
   // attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
-  for (size_t batch_and_query_idx = 0;
-       batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
+  for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
+       ++interleaved_idx) {
     // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
     // rearranging the weights.
-    float* HWY_RESTRICT att_out =
-        activations.att_out.Batch(batch_and_query_idx);
+    float* HWY_RESTRICT att_out = activations.att_out.Batch(interleaved_idx);
     float* HWY_RESTRICT layer_out =
-        activations.att_post2.Batch(batch_and_query_idx);
+        activations.att_post2.Batch(interleaved_idx);
     // Head 0 (and potentially biases) -> layer_out.
     // attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
     MatVecT</*kAdd=*/TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
@@ -364,6 +370,28 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
   }
 }
 
+template <class TConfig>
+HWY_NOINLINE void Attention(LayerAttentionType type, size_t interleaved_start,
+                            size_t num_tokens, size_t num_queries, size_t layer,
+                            Activations& activations,
+                            const CompressedLayer<TConfig>* layer_weights,
+                            const std::vector<KVCache*>& kv_caches,
+                            hwy::ThreadPool& pool) {
+  if (type == LayerAttentionType::kGemma) {
+    GemmaAttention<TConfig>(interleaved_start, num_tokens, num_queries, layer,
+                            activations, layer_weights, kv_caches, pool);
+  } else {
+    // Only reached if the model is Griffin. `if constexpr` prevents generating
+    // this code for non-Griffin models.
+    if constexpr (TConfig::kGriffinLayers > 0) {
+      HWY_ASSERT(num_queries == 1);
+      GriffinRecurrent<TConfig>(interleaved_start, num_tokens, num_queries,
+                                layer, activations, layer_weights, kv_caches,
+                                pool);
+    }
+  }
+}
+
 template <class TConfig, typename T>
 HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2,
                              size_t count) {
@@ -377,7 +405,7 @@ HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2,
 }
 
 template <class TConfig>
-HWY_NOINLINE void FFW(Activations& activations, size_t num_tokens,
+HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
                       const CompressedLayer<TConfig>* layer_weights,
                       hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.FFW");
@@ -388,7 +416,7 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_tokens,
   // elements in memory, repeated kFFHiddenDim times.
   constexpr size_t kColsA = kModelDim;
   constexpr size_t kColsB = kFFHiddenDim;
-  HWY_DASSERT(num_tokens <= activations.bf_pre_ffw_rms_out.BatchSize());
+  HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
   const auto A = activations.bf_pre_ffw_rms_out.All();
   const float scale = layer_weights->gating_einsum_w.scale();
   const auto B1 = layer_weights->gating_einsum_w.data();
@@ -400,18 +428,18 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_tokens,
   const auto bias2 = bias1 + kFFHiddenDim;
 
   // Will go through GELU.
-  MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_tokens, A, B1, scale, C1,
-                                                 bias1, pool);
+  MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_interleaved, A, B1, scale,
+                                                 C1, bias1, pool);
   // What to multiply by.
-  MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_tokens, A, B2, scale, C2,
-                                                 bias2, pool);
+  MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_interleaved, A, B2, scale,
+                                                 C2, bias2, pool);
 
   // Activation (Gelu) and multiply by gate. Store activations in C1.
-  Activation<TConfig>(C1, C2, kFFHiddenDim * num_tokens);
+  Activation<TConfig>(C1, C2, kFFHiddenDim * num_interleaved);
 
   // Hidden layer -> output layer.
   MatMul_4x4_Batch_Add<kFFHiddenDim, kModelDim, kAddBias>(
-      num_tokens, C1, layer_weights->linear_w.data(),
+      num_interleaved, C1, layer_weights->linear_w.data(),
       layer_weights->linear_w.scale(), activations.ffw_out.All(),
       layer_weights->ffw_output_biases.data_scale1(), pool);
 }
@@ -440,11 +468,18 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
 
 template <class TConfig, typename T>
 HWY_NOINLINE void ResidualConnection(
-    size_t num_tokens_and_queries, T* HWY_RESTRICT other, T* HWY_RESTRICT x,
+    size_t num_interleaved, T* HWY_RESTRICT other, T* HWY_RESTRICT x,
     const CompressedLayer<TConfig>* layer_weights, bool is_attention) {
   constexpr size_t kModelDim = TConfig::kModelDim;
   // ResidualType::Add
-  AddFromBatched(num_tokens_and_queries, other, x, kModelDim);
+  AddFromBatched(num_interleaved, other, x, kModelDim);
+}
+
+template <class TConfig, typename WeightT, typename InOutT>
+void PostNorm(size_t num_interleaved, const WeightT* weights, InOutT* inout) {
+  if (TConfig::kPostNorm == PostNormType::Scale) {
+    RMSNormInplaceBatched(num_interleaved, weights, inout, TConfig::kModelDim);
+  }
 }
 
 template <class TConfig>
@@ -453,46 +488,37 @@ HWY_NOINLINE void TransformerLayer(
     const CompressedLayer<TConfig>* layer_weights, Activations& activations,
     const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool) {
   constexpr size_t kModelDim = TConfig::kModelDim;
-  const size_t num_tokens_and_queries = num_tokens * num_queries;
+  const size_t num_interleaved = num_tokens * num_queries;
   auto type = TConfig::kLayerConfig[layer];
   size_t layer_of_type =
       NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
-  RMSNormBatched(num_tokens_and_queries, activations.x.All(),
+
+  RMSNormBatched(num_interleaved, activations.x.All(),
                  layer_weights->pre_attention_norm_scale.data_scale1(),
                  activations.pre_att_rms_out.All(), kModelDim);
-  if (type == LayerAttentionType::kGemma) {
-    Attention<TConfig>(pos, num_tokens, num_queries, layer_of_type, activations,
-                       layer_weights, kv_caches, pool);
-  } else {
-    // Only reached if the model is Griffin. `if constexpr` prevents generating
-    // this code for non-Griffin models.
-    if constexpr (TConfig::kGriffinLayers > 0) {
-      HWY_ASSERT(num_queries == 1);
-      GriffinRecurrent<TConfig>(pos, num_tokens, num_queries, layer_of_type,
-                                activations, layer_weights, kv_caches, pool);
-    }
-  }
 
-  if (TConfig::kPostNorm == PostNormType::Scale) {
-    RMSNormInplaceBatched(
-        num_tokens_and_queries,
-        layer_weights->post_attention_norm_scale.data_scale1(),
-        activations.att_post2.All(), kModelDim);
-  }
+  Attention<TConfig>(type, pos, num_tokens, num_queries, layer_of_type,
+                     activations, layer_weights, kv_caches, pool);
+
+  PostNorm<TConfig>(num_interleaved,
+                    layer_weights->post_attention_norm_scale.data_scale1(),
+                    activations.att_post2.All());
 
-  ResidualConnection<TConfig>(num_tokens_and_queries,
-                              activations.att_post2.All(), activations.x.All(),
-                              layer_weights, /*is_attention=*/true);
-  RMSNormBatched(num_tokens_and_queries, activations.x.All(),
+  ResidualConnection<TConfig>(num_interleaved, activations.att_post2.All(),
+                              activations.x.All(), layer_weights,
+                              /*is_attention=*/true);
+
+  RMSNormBatched(num_interleaved, activations.x.All(),
                  layer_weights->pre_ffw_norm_scale.data_scale1(),
                  activations.bf_pre_ffw_rms_out.All(), kModelDim);
-  FFW<TConfig>(activations, num_tokens_and_queries, layer_weights, pool);
-  if (TConfig::kPostNorm == PostNormType::Scale) {
-    RMSNormInplaceBatched(num_tokens_and_queries,
-                          layer_weights->post_ffw_norm_scale.data_scale1(),
-                          activations.ffw_out.All(), kModelDim);
-  }
-  ResidualConnection<TConfig>(num_tokens_and_queries, activations.ffw_out.All(),
+
+  FFW<TConfig>(activations, num_interleaved, layer_weights, pool);
+
+  PostNorm<TConfig>(num_interleaved,
+                    layer_weights->post_ffw_norm_scale.data_scale1(),
+                    activations.ffw_out.All());
+
+  ResidualConnection<TConfig>(num_interleaved, activations.ffw_out.All(),
                               activations.x.All(), layer_weights,
                               /*is_attention=*/false);
 }
@@ -669,16 +695,15 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
                               const std::vector<KVCache*>& kv_caches,
                               hwy::ThreadPool& pool,
                               const LayersOutputFunc& layers_output) {
-  const size_t num_tokens_and_queries = num_tokens * num_queries;
+  const size_t num_interleaved = num_tokens * num_queries;
   if (layers_output) {
-    for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
-         ++token_idx) {
+    for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
       float token_f = tokens[token_idx];
       layers_output(pos + token_idx, "Tokens", &token_f, 1);
     }
   }
   constexpr size_t kModelDim = TConfig::kModelDim;
-  for (size_t token_idx = 0; token_idx < num_tokens_and_queries; ++token_idx) {
+  for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
     EmbedToken<TConfig>(tokens[token_idx], token_idx, pos, weights,
                         activations);
   }
@@ -690,20 +715,17 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
 
     if (layers_output) {
       const std::string block_name = "blocks." + std::to_string(layer);
-      for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
-           ++token_idx) {
+      for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
        layers_output(pos + token_idx, block_name,
                       activations.x.Batch(token_idx), kModelDim);
       }
     }
   }
 
-  RMSNormInplaceBatched(num_tokens_and_queries,
-                        weights.final_norm_scale.data_scale1(),
+  RMSNormInplaceBatched(num_interleaved, weights.final_norm_scale.data_scale1(),
                         activations.x.All(), kModelDim);
   if (layers_output) {
-    for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
-         ++token_idx) {
+    for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
       layers_output(pos + token_idx, "final_norm",
                     activations.x.Batch(token_idx), kModelDim);
     }

@@ -268,17 +268,24 @@ void ForEachTensor(RawWeightsPtr raw_weights,
       GEMMA_CALL_FUNC("gr_a", griffin.a);
     }
     GEMMA_CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
+
+    // For conditionally-included tensors, the else branch must ensure their
+    // scale is initialized, because wrapper functions call data_scale1 even if
+    // the tensor turns out to be unused. If unused, the arrays are zero-length
+    // and data() returns a non-null but unusable pointer.
+
     if (TConfig::kPostNorm == PostNormType::Scale) {
       GEMMA_CALL_FUNC("post_att_ns", post_attention_norm_scale);
       GEMMA_CALL_FUNC("post_ff_ns", post_ffw_norm_scale);
+    } else {
+      c_layer->post_attention_norm_scale.set_scale(1.0f);
+      c_layer->post_ffw_norm_scale.set_scale(1.0f);
     }
 
     if (TConfig::kFFBiases) {
       GEMMA_CALL_FUNC("ffw_gat_b", ffw_gating_biases);
       GEMMA_CALL_FUNC("ffw_out_b", ffw_output_biases);
     } else {
-      // Ensure initialized so we can call data_scale1, which happens even if
-      // the tensor turns out to be unused.
       c_layer->ffw_gating_biases.set_scale(1.0f);
       c_layer->ffw_output_biases.set_scale(1.0f);
     }
@@ -287,8 +294,6 @@ void ForEachTensor(RawWeightsPtr raw_weights,
     if (TConfig::kSoftmaxAttnOutputBiases) {
      GEMMA_CALL_FUNC("attn_ob", attention_output_biases);
     } else {
-      // Ensure initialized so we can call data_scale1, which happens even if
-      // the tensor turns out to be unused.
      c_layer->attention_output_biases.set_scale(1.0f);
     }
   }