mirror of https://github.com/google/gemma.cpp.git
Use more parallelism in the attention block in prefill mode.

Move the loop over the tokens inside the attention block and
schedule kHeads * num_tokens tasks on the thread pool.
This improves multi-threaded speed only for the 2B Gemma model,
but for consistency we also move the token loop inside the
Griffin recurrent layer and the FFW layer. This also prepares
for using the MatMul operation later.

Benchmark results (summarization with 1600 prefill tokens
and essay writing with 500 generated tokens):
```
Prefill speed
Num threads BEFORE AFTER
32 61.76 t/s 65.08 t/s
64 89.46 t/s 98.62 t/s
```
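For reference, a minimal sketch (not part of the commit) of the flattened head-by-token task decomposition described above. `PoolStandIn`, `RunAttentionTasks`, and `attend_one` are hypothetical names introduced only for illustration; the actual change uses `hwy::ThreadPool` and the per-task indexing shown in the diff below.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>

// Stand-in for a thread pool's Run(begin, end, closure): here it executes the
// tasks serially; a real pool distributes them across worker threads.
struct PoolStandIn {
  void Run(uint64_t begin, uint64_t end,
           const std::function<void(uint64_t task, size_t thread)>& fn) {
    for (uint64_t task = begin; task < end; ++task) fn(task, /*thread=*/0);
  }
};

// One task per (head, token) pair, so prefill exposes kHeads * num_tokens
// independent units of work instead of only kHeads.
void RunAttentionTasks(
    PoolStandIn& pool, size_t kHeads, size_t num_tokens, size_t batch_start,
    const std::function<void(size_t head, size_t batch_idx, size_t pos)>&
        attend_one) {
  const size_t num_tasks = kHeads * num_tokens;
  pool.Run(0, num_tasks, [&](uint64_t task, size_t /*thread*/) {
    const size_t head = task % kHeads;       // fastest-varying index
    const size_t batch_idx = task / kHeads;  // token within the prefill batch
    const size_t pos = batch_start + batch_idx;
    attend_one(head, batch_idx, pos);
  });
}
```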
This commit is contained in:
parent 6eeef2e2d9
commit 3d72f17261

gemma/gemma.cc: 353 changed lines
```diff
@@ -460,7 +460,7 @@ KVCache CreateKVCacheT() {
   constexpr size_t kConv1dWidth = Config::kConv1dWidth;
   return CreateKVCache(
       Config::kGemmaLayers * Config::kKVHeads * Config::kQKVDim,
-      Config::kSeqLen,
+      Config::kSeqLen + kPrefillBatchSize,
       Config::kGriffinLayers * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) *
           Config::kModelDim,
       Config::kGriffinLayers * Config::kModelDim);
@@ -569,34 +569,39 @@ namespace HWY_NAMESPACE {
 
 template <size_t kBatchSize, typename LayerT, class TConfig>
 HWY_NOINLINE void GriffinRecurrent(
-    size_t batch_start, size_t batch_idx, size_t layer,
+    size_t batch_start, size_t num_tokens, size_t layer,
     Activations<TConfig, kBatchSize>& activations, const LayerT* layer_weights,
     KVCache& kv_cache, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Griffin");
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
-  HWY_DASSERT(batch_idx < kBatchSize);
+  HWY_DASSERT(num_tokens <= kBatchSize);
   static constexpr size_t kModelDim =
       gcpp::Activations<TConfig, kBatchSize>::kModelDim;
   static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
   static constexpr size_t kHeads = TConfig::kHeads;
   static constexpr bool kAdd = true;
-  const size_t batch_offset = batch_idx * kModelDim;
-  const size_t pos = batch_start + batch_idx;
 
   // X / Y linear layers.
-  float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
-  float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
-  TwoMatVecAdd<kAdd, kModelDim, kModelDim>(
-      layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0,
-      activations.pre_att_rms_out.data() + batch_offset,
-      /*add0=*/layer_weights->griffin.linear_x_biases.data(),
-      /*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x,
-      /*out1=*/y, pool);
-  Gelu(y, kModelDim);
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    const size_t pos = batch_start + batch_idx;
+    float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
+    float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
+    TwoMatVecAdd<kAdd, kModelDim, kModelDim>(
+        layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0,
+        activations.pre_att_rms_out.data() + batch_offset,
+        /*add0=*/layer_weights->griffin.linear_x_biases.data(),
+        /*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x,
+        /*out1=*/y, pool);
+    Gelu(y, kModelDim);
+  }
 
   // Conv1D.
   {
+    for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+      const size_t batch_offset = batch_idx * kModelDim;
+      const size_t pos = batch_start + batch_idx;
+      float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
       HWY_FULL(float) df;
       HWY_DASSERT(kModelDim % Lanes(df) == 0);
       const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1);
@@ -611,14 +616,15 @@ HWY_NOINLINE void GriffinRecurrent(
       }
       for (size_t i = 0; i < kModelDim; i += Lanes(df)) {
         auto xv = hn::Load(df, x + i);
-        auto accum0 = hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
+        auto accum0 =
+            hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
         auto accum1 = hn::Zero(df);
         static_assert(kConv1dWidth % 2 == 0, "Conv width must be even");
         for (size_t l = 0; 2 * l < kConv1dWidth; l++) {
           auto wv0 = hn::Load(df, layer_weights->griffin.conv_w.data() +
-                                  (kConv1dWidth - 1 - 2 * l) * kModelDim + i);
+                                      (kConv1dWidth - 1 - 2 * l) * kModelDim + i);
           auto wv1 = hn::Load(df, layer_weights->griffin.conv_w.data() +
-                                  (kConv1dWidth - 2 - 2 * l) * kModelDim + i);
+                                      (kConv1dWidth - 2 - 2 * l) * kModelDim + i);
           accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
           accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
         }
@@ -628,68 +634,80 @@ HWY_NOINLINE void GriffinRecurrent(
   }
 
   // RGLRU
-  float* HWY_RESTRICT gate_x = activations.griffin_gate_x.data() + batch_offset;
-  float* HWY_RESTRICT a = activations.griffin_multiplier.data() + batch_offset;
-  float* HWY_RESTRICT rnn_state =
-      kv_cache.rglru_cache.get() + layer * kModelDim;
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    const size_t pos = batch_start + batch_idx;
+    float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
+    float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
+    float* HWY_RESTRICT gate_x =
+        activations.griffin_gate_x.data() + batch_offset;
+    float* HWY_RESTRICT a =
+        activations.griffin_multiplier.data() + batch_offset;
+    float* HWY_RESTRICT rnn_state =
+        kv_cache.rglru_cache.get() + layer * kModelDim;
 
-  pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
-    constexpr size_t kHeadDim = kModelDim / kHeads;
-    constexpr size_t kMatrixSize = kHeadDim * kHeadDim;
-    size_t head_offset = head * kHeadDim;
-    TwoOfsMatVecAddLoop<kAdd, kHeadDim, kHeadDim>(
-        layer_weights->griffin.gate_w, kMatrixSize * head,
-        kMatrixSize * (kHeads + head), x + head_offset,
-        /*add0=*/layer_weights->griffin.gate_biases.data() + head_offset,
-        /*add1=*/layer_weights->griffin.gate_biases.data() + kModelDim +
-            head_offset,
-        /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
-    Sigmoid(gate_x + head_offset, kHeadDim);
-    Sigmoid(a + head_offset, kHeadDim);
-    const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
-                            HWY_ATTR { return hn::Mul(x, gate_x); };
-    hn::Transform1(D(), a + head_offset, kHeadDim,
-                   layer_weights->griffin.a.data() + head_offset, fn_mul);
-    hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
-                   fn_mul);
-    // RNN scan
-    HWY_FULL(float) df;
-    HWY_DASSERT(kHeadDim % Lanes(df) == 0);
-    for (size_t i = 0; i < kHeadDim; i += Lanes(df)) {
-      auto log_a = hn::Load(df, a + head_offset + i);
-      auto gated_x = hn::Load(df, x + head_offset + i);
-      auto rnn = hn::Load(df, rnn_state + head_offset + i);
-      auto a = hn::Exp(df, log_a);
-      auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0)));
-      if (pos == 0) {
-        x_multiplier = hn::Set(df, 1.0);
+    pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
+      constexpr size_t kHeadDim = kModelDim / kHeads;
+      constexpr size_t kMatrixSize = kHeadDim * kHeadDim;
+      size_t head_offset = head * kHeadDim;
+      TwoOfsMatVecAddLoop<kAdd, kHeadDim, kHeadDim>(
+          layer_weights->griffin.gate_w, kMatrixSize * head,
+          kMatrixSize * (kHeads + head), x + head_offset,
+          /*add0=*/layer_weights->griffin.gate_biases.data() + head_offset,
+          /*add1=*/layer_weights->griffin.gate_biases.data() + kModelDim +
+              head_offset,
+          /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
+      Sigmoid(gate_x + head_offset, kHeadDim);
+      Sigmoid(a + head_offset, kHeadDim);
+      const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
+                              HWY_ATTR { return hn::Mul(x, gate_x); };
+      hn::Transform1(D(), a + head_offset, kHeadDim,
+                     layer_weights->griffin.a.data() + head_offset, fn_mul);
+      hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
+                     fn_mul);
+      // RNN scan
+      HWY_FULL(float) df;
+      HWY_DASSERT(kHeadDim % Lanes(df) == 0);
+      for (size_t i = 0; i < kHeadDim; i += Lanes(df)) {
+        auto log_a = hn::Load(df, a + head_offset + i);
+        auto gated_x = hn::Load(df, x + head_offset + i);
+        auto rnn = hn::Load(df, rnn_state + head_offset + i);
+        auto a = hn::Exp(df, log_a);
+        auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0)));
+        if (pos == 0) {
+          x_multiplier = hn::Set(df, 1.0);
         }
-      auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn));
-      hn::Store(new_x, df, rnn_state + head_offset + i);
-
-      // Join branches.
-      auto yv = hn::Load(df, y + head_offset + i);
-      auto pre_out = hn::Mul(yv, new_x);
-      hn::Store(pre_out, df, x + head_offset + i);
-    }
+        auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn));
+        hn::Store(new_x, df, rnn_state + head_offset + i);
+
+        // Join branches.
+        auto yv = hn::Load(df, y + head_offset + i);
+        auto pre_out = hn::Mul(yv, new_x);
+        hn::Store(pre_out, df, x + head_offset + i);
+      }
-  });
+    });
+  }
 
   // Final linear layer.
-  float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
-  MatVecAdd<kAdd, kModelDim, kModelDim>(
-      layer_weights->griffin.linear_out_w, 0, x,
-      layer_weights->griffin.linear_out_biases.data(),
-      activations.even_odd.data(), out_ptr, pool);
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    const size_t pos = batch_start + batch_idx;
+    float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
+    float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
+    MatVecAdd<kAdd, kModelDim, kModelDim>(
+        layer_weights->griffin.linear_out_w, 0, x,
+        layer_weights->griffin.linear_out_biases.data(),
+        activations.even_odd.data(), out_ptr, pool);
+  }
 }
 
 template <size_t kBatchSize, typename LayerT, class TConfig>
-HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
+HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
                             Activations<TConfig, kBatchSize>& activations,
                             const LayerT* layer_weights, KVCache& kv_cache,
                             hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Attention");
-  const size_t pos = batch_start + batch_idx;
-  HWY_DASSERT(batch_idx < kBatchSize);
+  HWY_DASSERT(num_tokens <= kBatchSize);
   static constexpr size_t kQKVDim = gcpp::Activations<TConfig, 1>::kQKVDim;
   static constexpr size_t kCachePosSize =
       gcpp::Activations<TConfig, kBatchSize>::kCachePosSize;
@@ -699,47 +717,43 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
       gcpp::Activations<TConfig, kBatchSize>::kModelDim;
   static constexpr size_t kHeads = TConfig::kHeads;
   static constexpr size_t kKVHeads = TConfig::kKVHeads;
+  static constexpr size_t kSeqLen = TConfig::kSeqLen;
   static const float kQueryScale =
       static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
 
-  size_t cache_pos = pos;
-  size_t cache_num = pos + 1;
-  if constexpr (TConfig::kUseLocalAttention) {
-    cache_pos %= TConfig::kSeqLen;
-    cache_num = std::min(cache_num, static_cast<size_t>(TConfig::kSeqLen));
-  }
-
-  float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
-
-  auto Attn = [&](float* q, uint64_t head, size_t head_offset,
+  auto Attn = [&](float* q, uint64_t head, size_t head_offset, size_t batch_idx,
                   size_t thread) HWY_ATTR {
+    const size_t pos = batch_start + batch_idx;
    // Calculate scores
    float* HWY_RESTRICT head_att = activations.att.data() +
-                                   head * TConfig::kSeqLen +
-                                   batch_idx * kHeads * kQKVDim;
+                                   head * kSeqLen +
+                                   batch_idx * kHeads * kSeqLen;
 
    Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
    MulByConst(kQueryScale, q, kQKVDim);
 
    // Compute Q dot K scores
-    for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
-      const size_t cache_offset =
-          pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
-      const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + cache_offset;
+    const size_t start_pos = pos - std::min(kSeqLen - 1, pos);
+    for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
+      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset = cache_pos * kCachePosSize +
+                               layer * kCacheLayerSize + head_offset;
+      const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + kv_offset;
      const float score = Dot(q, k2, kQKVDim);
-      head_att[pos2] = score;
+      head_att[pos2 % kSeqLen] = score;
    }
-    Softmax(head_att, cache_num);
+    Softmax(head_att, std::min(pos + 1, kSeqLen));
 
    // Weighted summation
    float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
                                  batch_idx * kHeads * kQKVDim;
    hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
-    for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
-      const size_t cache_offset =
-          pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
-      float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + cache_offset + kQKVDim;
-      MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
+    for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
+      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset = cache_pos * kCachePosSize +
+                               layer * kCacheLayerSize + head_offset;
+      float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + kv_offset + kQKVDim;
+      MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim);
    }
  };
 
@@ -747,74 +761,99 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
    // Multi-Head Attention
    static_assert(TConfig::kInterleaveQKV);
 
-    float* HWY_RESTRICT qkv =
-        activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
-    MatVec<kHeads * kQKVDim * 3, kModelDim>(
-        layer_weights->qkv_einsum_w, 0, x, activations.even_odd.data(), qkv,
-        pool);
-
-    pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      float* HWY_RESTRICT q = qkv + head * kQKVDim * 3;
+    for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+      float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
+      float* HWY_RESTRICT qkv =
+          activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
+      MatVec<kHeads * kQKVDim * 3, kModelDim>(
+          layer_weights->qkv_einsum_w, 0, x, activations.even_odd.data(), qkv,
+          pool);
+    }
+    const size_t num_tasks = kHeads * num_tokens;
+    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+      const size_t head = task % kHeads;
+      const size_t batch_idx = task / kHeads;
+      const size_t pos = batch_start + batch_idx;
+      float* HWY_RESTRICT q =
+          activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
+      const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
      const size_t kv_offset = cache_pos * kCachePosSize +
                               layer * kCacheLayerSize + head * kQKVDim * 2;
      float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
 
      memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
      Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
-      Attn(q, head, head * kQKVDim * 2, thread);
    });
+    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+      const size_t head = task % kHeads;
+      const size_t batch_idx = task / kHeads;
+      float* HWY_RESTRICT q =
+          activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
+      Attn(q, head, head * kQKVDim * 2, batch_idx, thread);
+    });
  } else {
    // Multi-Query Attention
-    float* HWY_RESTRICT q = activations.q.data() + batch_idx * kHeads * kQKVDim;
-    MatVec<kHeads * kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, 0, x,
-                                        activations.even_odd.data(), q, pool);
+    for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+      const size_t pos = batch_start + batch_idx;
+      float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
 
-    float* HWY_RESTRICT kv = kv_cache.kv_cache.get() +
-                             cache_pos * kCachePosSize +
-                             layer * kCacheLayerSize;
-    MatVec<kQKVDim * 2, kModelDim>(layer_weights->qkv_einsum_w,
-                                   kHeads * kQKVDim * kModelDim, x,
-                                   activations.even_odd.data(), kv, pool);
+      float* HWY_RESTRICT q =
+          activations.q.data() + batch_idx * kHeads * kQKVDim;
+      MatVec<kHeads * kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, 0, x,
+                                          activations.even_odd.data(), q, pool);
 
-    Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
-
-    pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      Attn(q + head * kQKVDim, head, 0, thread);
+      const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset = cache_pos * kCachePosSize +
+                               layer * kCacheLayerSize;
+      float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
+      MatVec<kQKVDim * 2, kModelDim>(layer_weights->qkv_einsum_w,
+                                     kHeads * kQKVDim * kModelDim, x,
+                                     activations.even_odd.data(), kv, pool);
+      Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+    }
+    const size_t num_tasks = kHeads * num_tokens;
+    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+      const size_t head = task % kHeads;
+      const size_t batch_idx = task / kHeads;
+      float* HWY_RESTRICT q =
+          activations.q.data() + batch_idx * kHeads * kQKVDim;
+      Attn(q + head * kQKVDim, head, 0, batch_idx, thread);
    });
  }
 
-  // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
-  // rearranging the weights.
-  float* HWY_RESTRICT att_out =
-      activations.att_out.data() + batch_idx * kHeads * kQKVDim;
-  float* HWY_RESTRICT layer_out =
-      activations.att_post2.data() + batch_idx * kModelDim;
-  MatVecAdd<TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
-      layer_weights->attn_vec_einsum_w, 0, att_out,
-      layer_weights->attention_output_biases.data(),
-      activations.even_odd.data(), layer_out, pool);
-  for (size_t head = 1; head < kHeads; ++head) {
-    float* HWY_RESTRICT head_out =
-        activations.att_post1.data() + head * kBatchSize * kModelDim;
-    MatVec<kModelDim, kQKVDim>(
-        layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim,
-        att_out + head * kQKVDim,
-        activations.even_odd.data(), head_out, pool);
-    AddFrom(head_out, layer_out, kModelDim);
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
+    // rearranging the weights.
+    float* HWY_RESTRICT att_out =
+        activations.att_out.data() + batch_idx * kHeads * kQKVDim;
+    float* HWY_RESTRICT layer_out =
+        activations.att_post2.data() + batch_idx * kModelDim;
+    MatVecAdd<TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
+        layer_weights->attn_vec_einsum_w, 0, att_out,
+        layer_weights->attention_output_biases.data(),
+        activations.even_odd.data(), layer_out, pool);
+    for (size_t head = 1; head < kHeads; ++head) {
+      float* HWY_RESTRICT head_out =
+          activations.att_post1.data() + head * kBatchSize * kModelDim;
+      MatVec<kModelDim, kQKVDim>(
+          layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim,
+          att_out + head * kQKVDim,
+          activations.even_odd.data(), head_out, pool);
+      AddFrom(head_out, layer_out, kModelDim);
+    }
  }
 }
 
 template <size_t kBatchSize, typename LayerT, typename TConfig>
 HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
-                      size_t batch_idx, const LayerT* layer_weights,
+                      size_t num_tokens, const LayerT* layer_weights,
                       hwy::ThreadPool& pool) {
-  HWY_DASSERT(batch_idx < kBatchSize);
   static constexpr size_t kModelDim = TConfig::kModelDim;
   static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
-  const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
   float* HWY_RESTRICT even_odd = activations.even_odd.data();
 
-  {
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
    PROFILER_ZONE("Gen.FFW.GatedGELU");
    const hwy::bfloat16_t* HWY_RESTRICT vec =
        activations.bf_pre_ffw_rms_out.data() + batch_idx * kModelDim;
@@ -839,11 +878,15 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
                   HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); });
  }
 
-  PROFILER_ZONE("Gen.FFW\\GatedGELU");
-  MatVecAdd<TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
-      layer_weights->linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
-      layer_weights->ffw_output_biases.data(), even_odd,
-      activations.ffw_out.data() + batch_idx * kModelDim, pool);
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    PROFILER_ZONE("Gen.FFW\\GatedGELU");
+    const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
+    MatVecAdd<TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
+        layer_weights->linear_w, 0,
+        activations.ffw_hidden.data() + hidden_offset,
+        layer_weights->ffw_output_biases.data(), even_odd,
+        activations.ffw_out.data() + batch_idx * kModelDim, pool);
+  }
 }
 
 // `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
@@ -898,24 +941,26 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
              layer_weights->pre_attention_norm_scale.data(),
              activations.pre_att_rms_out.data() + token_idx * kModelDim,
              kModelDim);
-      if (type == LayerAttentionType::kGemma) {
-        Attention<kBatchSize>(pos, token_idx, layer_of_type, activations,
-                              layer_weights, kv_cache, pool);
-      } else {
-        GriffinRecurrent<kBatchSize>(pos, token_idx, layer_of_type, activations,
-                                     layer_weights, kv_cache, pool);
-      }
    }
+    if (type == LayerAttentionType::kGemma) {
+      Attention<kBatchSize>(pos, num_tokens, layer_of_type, activations,
+                            layer_weights, kv_cache, pool);
+    } else {
+      GriffinRecurrent<kBatchSize>(pos, num_tokens, layer_of_type, activations,
+                                   layer_weights, kv_cache, pool);
+    }
 
-    for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    // TODO: sink the loop into these functions, i.e. make them MatMul.
+    pool.Run(0, num_tokens, [&](const uint64_t token_idx,
+                                size_t /*thread*/) HWY_ATTR {
      AddFrom(activations.att_post2.data() + token_idx * kModelDim,
              activations.x.data() + token_idx * kModelDim, kModelDim);
      RMSNorm(activations.x.data() + token_idx * kModelDim,
              layer_weights->pre_ffw_norm_scale.data(),
              activations.bf_pre_ffw_rms_out.data() + token_idx * kModelDim,
              kModelDim);
-      FFW<kBatchSize>(activations, token_idx, layer_weights, pool);
+    });
+    FFW<kBatchSize>(activations, num_tokens, layer_weights, pool);
+    for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
      AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
              activations.x.data() + token_idx * kModelDim, kModelDim);
    }
@@ -957,16 +1002,16 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
            layer_weights->pre_attention_norm_scale.data(),
            activations.pre_att_rms_out.data(), kModelDim);
    if (type == LayerAttentionType::kGemma) {
-      Attention<1>(pos, 0, layer_of_type, activations, layer_weights, kv_cache,
+      Attention<1>(pos, 1, layer_of_type, activations, layer_weights, kv_cache,
                   pool);
    } else {
-      GriffinRecurrent<1>(pos, 0, layer_of_type, activations, layer_weights,
+      GriffinRecurrent<1>(pos, 1, layer_of_type, activations, layer_weights,
                          kv_cache, pool);
    }
    AddFrom(activations.att_post2.data(), activations.x.data(), kModelDim);
    RMSNorm(activations.x.data(), layer_weights->pre_ffw_norm_scale.data(),
            activations.bf_pre_ffw_rms_out.data(), kModelDim);
-    FFW<1>(activations, /* batch_idx = */ 0, layer_weights, pool);
+    FFW<1>(activations, /* num_tokens = */ 1, layer_weights, pool);
    AddFrom(activations.ffw_out.data(), activations.x.data(), kModelDim);
    if (layers_output != nullptr) {
      std::string block_name = "blocks." + std::to_string(layer);
```