mirror of https://github.com/google/gemma.cpp.git
Merge pull request #177 from szabadka:gemma2
PiperOrigin-RevId: 630388843
This commit is contained in: commit 8ed22e52bf
@@ -247,7 +247,6 @@ struct CompressTraits<hwy::bfloat16_t> {
     using VU32 = hn::VFromD<decltype(du32)>;
     const VU32 odd = Set(du32, 0xFFFF0000u);
 
-    VF32 be0, bo0, be1, bo1;
     for (size_t i = 0; i < num; /* i += 2 * N */) {
       const auto interleaved0 = hn::LoadU(dbf16, in + in_ofs + i);
       const VF32 ae0 = Load(df32, vec_aligned + i);
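The hunk above works on bfloat16 values packed two-per-32-bit lane; the 0xFFFF0000u mask selects the odd (upper-half) element. A scalar sketch of that layout, illustrative only and not the Highway SIMD path used in the patch:

    #include <cstdint>
    #include <cstring>

    // A bfloat16 is the upper 16 bits of an IEEE-754 float, so widening it
    // back to float is a 16-bit left shift.
    static float BF16BitsToFloat(uint16_t bits) {
      const uint32_t u = static_cast<uint32_t>(bits) << 16;
      float f;
      std::memcpy(&f, &u, sizeof(f));
      return f;
    }

    // Two consecutive bf16 values in one 32-bit word: the "even" value sits in
    // the low half, the "odd" value in the high half (mask 0xFFFF0000u).
    static void UnpackBF16Pair(uint32_t packed, float* even, float* odd) {
      *even = BF16BitsToFloat(static_cast<uint16_t>(packed & 0xFFFFu));
      *odd = BF16BitsToFloat(static_cast<uint16_t>(packed >> 16));
    }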

gemma/gemma.cc (165 lines changed)
@@ -460,7 +460,7 @@ KVCache CreateKVCacheT() {
   constexpr size_t kConv1dWidth = Config::kConv1dWidth;
   return CreateKVCache(
       Config::kGemmaLayers * Config::kKVHeads * Config::kQKVDim,
-      Config::kSeqLen,
+      Config::kSeqLen + kPrefillBatchSize,
       Config::kGriffinLayers * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) *
           Config::kModelDim,
       Config::kGriffinLayers * Config::kModelDim);
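The extra kPrefillBatchSize rows let a prefill batch be written past the kSeqLen window, with positions addressed as a ring buffer. A minimal sketch of the resulting slot computation (constants are example values, not the configs' real ones), matching the `pos % (kSeqLen + kPrefillBatchSize)` expressions that appear later in Attention:

    #include <cstddef>

    // Ring-buffer slot for a token at absolute position `pos` once the cache
    // holds seq_len + prefill_batch positions.
    constexpr size_t CachePos(size_t pos, size_t seq_len, size_t prefill_batch) {
      return pos % (seq_len + prefill_batch);
    }

    static_assert(CachePos(/*pos=*/0, /*seq_len=*/8, /*prefill_batch=*/4) == 0, "");
    static_assert(CachePos(/*pos=*/13, /*seq_len=*/8, /*prefill_batch=*/4) == 1, "");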
@@ -569,22 +569,22 @@ namespace HWY_NAMESPACE {
 
 template <size_t kBatchSize, typename LayerT, class TConfig>
 HWY_NOINLINE void GriffinRecurrent(
-    size_t batch_start, size_t batch_idx, size_t layer,
+    size_t batch_start, size_t num_tokens, size_t layer,
     Activations<TConfig, kBatchSize>& activations, const LayerT* layer_weights,
     KVCache& kv_cache, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Griffin");
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
-  HWY_DASSERT(batch_idx < kBatchSize);
+  HWY_DASSERT(num_tokens <= kBatchSize);
   static constexpr size_t kModelDim =
       gcpp::Activations<TConfig, kBatchSize>::kModelDim;
   static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
   static constexpr size_t kHeads = TConfig::kHeads;
   static constexpr bool kAdd = true;
-  const size_t batch_offset = batch_idx * kModelDim;
-  const size_t pos = batch_start + batch_idx;
 
   // X / Y linear layers.
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
     float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
     float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
     TwoMatVecAdd<kAdd, kModelDim, kModelDim>(
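GriffinRecurrent now receives num_tokens and runs each stage over the whole batch, with every token owning a kModelDim-sized slice at batch_idx * kModelDim. The pattern used throughout the following hunks, sketched with illustrative names:

    #include <cstddef>

    // Sketch of the batching pattern adopted here: each stage walks all tokens
    // in the batch and operates on its token's kModelDim-sized slice.
    template <typename Stage>
    void ForEachToken(size_t num_tokens, size_t model_dim, float* activations,
                      const Stage& stage) {
      for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
        float* slice = activations + batch_idx * model_dim;  // per-token view
        stage(batch_idx, slice);
      }
    }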
@@ -594,9 +594,13 @@ HWY_NOINLINE void GriffinRecurrent(
         /*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x,
         /*out1=*/y, pool);
     Gelu(y, kModelDim);
+  }
 
   // Conv1D.
-  {
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    const size_t pos = batch_start + batch_idx;
+    float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
     HWY_FULL(float) df;
     HWY_DASSERT(kModelDim % Lanes(df) == 0);
     const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1);
@@ -611,7 +615,8 @@ HWY_NOINLINE void GriffinRecurrent(
     }
     for (size_t i = 0; i < kModelDim; i += Lanes(df)) {
       auto xv = hn::Load(df, x + i);
-      auto accum0 = hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
+      auto accum0 =
+          hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
       auto accum1 = hn::Zero(df);
       static_assert(kConv1dWidth % 2 == 0, "Conv width must be even");
       for (size_t l = 0; 2 * l < kConv1dWidth; l++) {
@@ -628,8 +633,15 @@ HWY_NOINLINE void GriffinRecurrent(
   }
 
   // RGLRU
-  float* HWY_RESTRICT gate_x = activations.griffin_gate_x.data() + batch_offset;
-  float* HWY_RESTRICT a = activations.griffin_multiplier.data() + batch_offset;
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    const size_t pos = batch_start + batch_idx;
+    float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
+    float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
+    float* HWY_RESTRICT gate_x =
+        activations.griffin_gate_x.data() + batch_offset;
+    float* HWY_RESTRICT a =
+        activations.griffin_multiplier.data() + batch_offset;
     float* HWY_RESTRICT rnn_state =
         kv_cache.rglru_cache.get() + layer * kModelDim;
 
@@ -673,23 +685,27 @@ HWY_NOINLINE void GriffinRecurrent(
       hn::Store(pre_out, df, x + head_offset + i);
     }
   });
+  }
 
   // Final linear layer.
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
     float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
     MatVecAdd<kAdd, kModelDim, kModelDim>(
         layer_weights->griffin.linear_out_w, 0, x,
         layer_weights->griffin.linear_out_biases.data(),
         activations.even_odd.data(), out_ptr, pool);
   }
+}
 
 template <size_t kBatchSize, typename LayerT, class TConfig>
-HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
+HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
                             Activations<TConfig, kBatchSize>& activations,
                             const LayerT* layer_weights, KVCache& kv_cache,
                             hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Attention");
-  const size_t pos = batch_start + batch_idx;
-  HWY_DASSERT(batch_idx < kBatchSize);
+  HWY_DASSERT(num_tokens <= kBatchSize);
   static constexpr size_t kQKVDim = gcpp::Activations<TConfig, 1>::kQKVDim;
   static constexpr size_t kCachePosSize =
       gcpp::Activations<TConfig, kBatchSize>::kCachePosSize;
@@ -699,47 +715,43 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
       gcpp::Activations<TConfig, kBatchSize>::kModelDim;
   static constexpr size_t kHeads = TConfig::kHeads;
   static constexpr size_t kKVHeads = TConfig::kKVHeads;
+  static constexpr size_t kSeqLen = TConfig::kSeqLen;
   static const float kQueryScale =
       static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
 
-  size_t cache_pos = pos;
-  size_t cache_num = pos + 1;
-  if constexpr (TConfig::kUseLocalAttention) {
-    cache_pos %= TConfig::kSeqLen;
-    cache_num = std::min(cache_num, static_cast<size_t>(TConfig::kSeqLen));
-  }
-
-  float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
-
-  auto Attn = [&](float* q, uint64_t head, size_t head_offset,
+  auto Attn = [&](float* q, uint64_t head, size_t head_offset, size_t batch_idx,
                   size_t thread) HWY_ATTR {
+    const size_t pos = batch_start + batch_idx;
     // Calculate scores
     float* HWY_RESTRICT head_att = activations.att.data() +
-                                   head * TConfig::kSeqLen +
-                                   batch_idx * kHeads * kQKVDim;
+                                   head * kSeqLen +
+                                   batch_idx * kHeads * kSeqLen;
 
     Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
     MulByConst(kQueryScale, q, kQKVDim);
 
     // Compute Q dot K scores
-    for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
-      const size_t cache_offset =
-          pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
-      const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + cache_offset;
+    const size_t start_pos = pos - std::min(kSeqLen - 1, pos);
+    for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
+      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset = cache_pos * kCachePosSize +
+                               layer * kCacheLayerSize + head_offset;
+      const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + kv_offset;
       const float score = Dot(q, k2, kQKVDim);
-      head_att[pos2] = score;
+      head_att[pos2 % kSeqLen] = score;
     }
-    Softmax(head_att, cache_num);
+    Softmax(head_att, std::min(pos + 1, kSeqLen));
 
     // Weighted summation
     float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
                                   batch_idx * kHeads * kQKVDim;
     hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
-    for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
-      const size_t cache_offset =
-          pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
-      float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + cache_offset + kQKVDim;
-      MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
+    for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
+      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset = cache_pos * kCachePosSize +
+                               layer * kCacheLayerSize + head_offset;
+      float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + kv_offset + kQKVDim;
+      MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim);
    }
   };
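The rewritten Attn lambda attends over at most the last kSeqLen positions: start_pos clamps the window at zero, KV rows are fetched modulo kSeqLen + kPrefillBatchSize, and scores land in a per-head buffer indexed modulo kSeqLen. A small sketch of that index math, with kSeqLen and kPrefillBatchSize as example values rather than the real config constants:

    #include <algorithm>
    #include <cstddef>

    constexpr size_t kSeqLen = 8;
    constexpr size_t kPrefillBatchSize = 4;

    // First position a query at `pos` attends to: at most kSeqLen tokens,
    // clamped so it never underflows for early positions.
    constexpr size_t StartPos(size_t pos) {
      return pos - std::min(kSeqLen - 1, pos);
    }

    // Where position pos2 lives in the enlarged KV ring buffer, and where its
    // score lands in the per-head attention buffer of length kSeqLen.
    constexpr size_t KVSlot(size_t pos2) {
      return pos2 % (kSeqLen + kPrefillBatchSize);
    }
    constexpr size_t ScoreSlot(size_t pos2) { return pos2 % kSeqLen; }

    static_assert(StartPos(3) == 0, "early positions start at 0");
    static_assert(StartPos(20) == 13, "later positions keep a kSeqLen window");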
@@ -747,42 +759,66 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     // Multi-Head Attention
     static_assert(TConfig::kInterleaveQKV);
 
+    for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+      float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
       float* HWY_RESTRICT qkv =
           activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
       MatVec<kHeads * kQKVDim * 3, kModelDim>(
           layer_weights->qkv_einsum_w, 0, x, activations.even_odd.data(), qkv,
           pool);
+    }
-    pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      float* HWY_RESTRICT q = qkv + head * kQKVDim * 3;
+    const size_t num_tasks = kHeads * num_tokens;
+    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+      const size_t head = task % kHeads;
+      const size_t batch_idx = task / kHeads;
+      const size_t pos = batch_start + batch_idx;
+      float* HWY_RESTRICT q =
+          activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
+      const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
       const size_t kv_offset = cache_pos * kCachePosSize +
                                layer * kCacheLayerSize + head * kQKVDim * 2;
       float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
 
       memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
       Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
-      Attn(q, head, head * kQKVDim * 2, thread);
+    });
+    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+      const size_t head = task % kHeads;
+      const size_t batch_idx = task / kHeads;
+      float* HWY_RESTRICT q =
+          activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
+      Attn(q, head, head * kQKVDim * 2, batch_idx, thread);
     });
   } else {
     // Multi-Query Attention
-    float* HWY_RESTRICT q = activations.q.data() + batch_idx * kHeads * kQKVDim;
+    for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+      const size_t pos = batch_start + batch_idx;
+      float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
+      float* HWY_RESTRICT q =
+          activations.q.data() + batch_idx * kHeads * kQKVDim;
       MatVec<kHeads * kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, 0, x,
                                           activations.even_odd.data(), q, pool);
 
-    float* HWY_RESTRICT kv = kv_cache.kv_cache.get() +
-                             cache_pos * kCachePosSize +
+      const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset = cache_pos * kCachePosSize +
                                layer * kCacheLayerSize;
+      float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
       MatVec<kQKVDim * 2, kModelDim>(layer_weights->qkv_einsum_w,
                                      kHeads * kQKVDim * kModelDim, x,
                                      activations.even_odd.data(), kv, pool);
 
       Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+    }
-    pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      Attn(q + head * kQKVDim, head, 0, thread);
+    const size_t num_tasks = kHeads * num_tokens;
+    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+      const size_t head = task % kHeads;
+      const size_t batch_idx = task / kHeads;
+      float* HWY_RESTRICT q =
+          activations.q.data() + batch_idx * kHeads * kQKVDim;
+      Attn(q + head * kQKVDim, head, 0, batch_idx, thread);
     });
   }
 
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
     // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
     // rearranging the weights.
     float* HWY_RESTRICT att_out =
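Both attention variants now flatten heads and tokens into kHeads * num_tokens pool tasks and decode each task back into a (head, token) pair. Illustrative decoding of that scheme:

    #include <cstddef>
    #include <cstdint>

    // How a flattened task index maps back to (head, token) in the pool.Run
    // lambdas above; the head count here is just an example.
    struct HeadToken {
      size_t head;
      size_t batch_idx;
    };

    constexpr HeadToken DecodeTask(uint64_t task, size_t num_heads) {
      return HeadToken{static_cast<size_t>(task % num_heads),
                       static_cast<size_t>(task / num_heads)};
    }

    // With 4 heads, tasks 0..3 are the heads of token 0, tasks 4..7 token 1, etc.
    static_assert(DecodeTask(6, 4).head == 2 && DecodeTask(6, 4).batch_idx == 1, "");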
@@ -803,18 +839,19 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
       AddFrom(head_out, layer_out, kModelDim);
     }
   }
+  }
 
 template <size_t kBatchSize, typename LayerT, typename TConfig>
 HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
-                      size_t batch_idx, const LayerT* layer_weights,
+                      size_t num_tokens, const LayerT* layer_weights,
                       hwy::ThreadPool& pool) {
-  HWY_DASSERT(batch_idx < kBatchSize);
+  HWY_DASSERT(num_tokens <= kBatchSize);
   static constexpr size_t kModelDim = TConfig::kModelDim;
   static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
-  const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
   float* HWY_RESTRICT even_odd = activations.even_odd.data();
 
-  {
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
     PROFILER_ZONE("Gen.FFW.GatedGELU");
     const hwy::bfloat16_t* HWY_RESTRICT vec =
         activations.bf_pre_ffw_rms_out.data() + batch_idx * kModelDim;
@@ -839,12 +876,16 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
                HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); });
   }
 
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
     PROFILER_ZONE("Gen.FFW\\GatedGELU");
+    const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
     MatVecAdd<TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
-        layer_weights->linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
+        layer_weights->linear_w, 0,
+        activations.ffw_hidden.data() + hidden_offset,
         layer_weights->ffw_output_biases.data(), even_odd,
         activations.ffw_out.data() + batch_idx * kModelDim, pool);
   }
+}
 
 // `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
 // are both constexpr
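FFW's two stages now each loop over the batch: the gated-GELU pass fills a per-token slice of ffw_hidden (two halves of size kFFHiddenDim), then the output MatVecAdd projects it back to kModelDim. A scalar sketch of the gated-GELU step; which half is gated versus multiplied, and the exact GELU approximation, are assumptions for illustration and the real kernel is the Highway code above:

    #include <cmath>
    #include <cstddef>

    // tanh approximation of GELU; the library's Gelu may differ in detail.
    static float GeluScalar(float x) {
      const float kSqrt2OverPi = 0.7978845608f;
      return 0.5f * x *
             (1.0f + std::tanh(kSqrt2OverPi * (x + 0.044715f * x * x * x)));
    }

    // hidden holds 2 * ff_hidden_dim values per token; one half is gated by
    // GELU and multiplied elementwise with the other half.
    static void GatedGelu(const float* hidden, size_t ff_hidden_dim, float* out) {
      for (size_t i = 0; i < ff_hidden_dim; ++i) {
        out[i] = GeluScalar(hidden[i]) * hidden[ff_hidden_dim + i];
      }
    }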
@@ -898,24 +939,26 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
               layer_weights->pre_attention_norm_scale.data(),
               activations.pre_att_rms_out.data() + token_idx * kModelDim,
               kModelDim);
+    }
     if (type == LayerAttentionType::kGemma) {
-      Attention<kBatchSize>(pos, token_idx, layer_of_type, activations,
+      Attention<kBatchSize>(pos, num_tokens, layer_of_type, activations,
                             layer_weights, kv_cache, pool);
     } else {
-      GriffinRecurrent<kBatchSize>(pos, token_idx, layer_of_type, activations,
+      GriffinRecurrent<kBatchSize>(pos, num_tokens, layer_of_type, activations,
                                    layer_weights, kv_cache, pool);
     }
-    }
 
-    // TODO: sink the loop into these functions, i.e. make them MatMul.
-    for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    pool.Run(0, num_tokens, [&](const uint64_t token_idx,
+                                size_t /*thread*/) HWY_ATTR {
       AddFrom(activations.att_post2.data() + token_idx * kModelDim,
               activations.x.data() + token_idx * kModelDim, kModelDim);
       RMSNorm(activations.x.data() + token_idx * kModelDim,
               layer_weights->pre_ffw_norm_scale.data(),
               activations.bf_pre_ffw_rms_out.data() + token_idx * kModelDim,
               kModelDim);
-      FFW<kBatchSize>(activations, token_idx, layer_weights, pool);
+    });
+    FFW<kBatchSize>(activations, num_tokens, layer_weights, pool);
+    for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
       AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
               activations.x.data() + token_idx * kModelDim, kModelDim);
     }
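Prefill now parallelizes the residual add and pre-FFW RMSNorm across tokens with pool.Run, which is safe because each task writes only its own token's kModelDim slice. A simplified scalar sketch of the per-token residual step:

    #include <cstddef>

    // Why the pool.Run over tokens in Prefill is race-free (illustrative):
    // each task touches a disjoint kModelDim-sized range of x and att_post2.
    static void AddResidualForToken(size_t token_idx, size_t model_dim,
                                    const float* att_post2, float* x) {
      const float* src = att_post2 + token_idx * model_dim;
      float* dst = x + token_idx * model_dim;
      for (size_t i = 0; i < model_dim; ++i) {
        dst[i] += src[i];  // disjoint ranges across token_idx values
      }
    }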
@@ -957,16 +1000,16 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
             layer_weights->pre_attention_norm_scale.data(),
             activations.pre_att_rms_out.data(), kModelDim);
     if (type == LayerAttentionType::kGemma) {
-      Attention<1>(pos, 0, layer_of_type, activations, layer_weights, kv_cache,
+      Attention<1>(pos, 1, layer_of_type, activations, layer_weights, kv_cache,
                    pool);
     } else {
-      GriffinRecurrent<1>(pos, 0, layer_of_type, activations, layer_weights,
+      GriffinRecurrent<1>(pos, 1, layer_of_type, activations, layer_weights,
                           kv_cache, pool);
     }
     AddFrom(activations.att_post2.data(), activations.x.data(), kModelDim);
     RMSNorm(activations.x.data(), layer_weights->pre_ffw_norm_scale.data(),
             activations.bf_pre_ffw_rms_out.data(), kModelDim);
-    FFW<1>(activations, /* batch_idx = */ 0, layer_weights, pool);
+    FFW<1>(activations, /* num_tokens = */ 1, layer_weights, pool);
     AddFrom(activations.ffw_out.data(), activations.x.data(), kModelDim);
     if (layers_output != nullptr) {
       std::string block_name = "blocks." + std::to_string(layer);