1.03-1.08x decode speedup: precompute Rope theta, fuse

Split attention into functions, move into class.
Fuse Rope and MulBy, allow non-in-place version to avoid copy from q to KV.
Sink if() into MaybeLogitsSoftCap.

PiperOrigin-RevId: 661168418
This commit is contained in:
Jan Wassenberg 2024-08-09 01:22:46 -07:00 committed by Copybara-Service
parent 27258b03e6
commit 2ebbe4076f
12 changed files with 443 additions and 278 deletions

View File

@ -455,6 +455,7 @@ cc_test(
":common",
":gemma_lib",
":ops",
":prompt",
":sampler",
"@googletest//:gtest_main",
"//compression:weights_raw",

View File

@ -27,6 +27,7 @@
#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "gemma/activations.h" // CreateInvTimescale
#include "gemma/common.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -166,13 +167,12 @@ static HWY_NOINLINE void InputEmbeddingVJP(
}
}
template <typename TConfig, template<typename> typename LayerT>
template <typename TConfig, template <typename> typename LayerT>
void LayerVJP(const LayerT<TConfig>& weights,
const ForwardLayer<float, TConfig>& forward,
const float* HWY_RESTRICT next_layer_grad,
size_t num_tokens,
LayerT<TConfig>& grad,
ForwardLayer<float, TConfig>& backward,
const float* HWY_RESTRICT next_layer_grad, size_t num_tokens,
LayerT<TConfig>& grad, ForwardLayer<float, TConfig>& backward,
const RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
@ -279,7 +279,7 @@ void LayerVJP(const LayerT<TConfig>& weights,
for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
float* HWY_RESTRICT b_kv =
backward.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
Rope(b_kv, kQKVDim, -pos);
Rope(b_kv, kQKVDim, inv_timescale.Const(), -pos);
}
for (size_t head = 0; head < kHeads; ++head) {
@ -287,7 +287,7 @@ void LayerVJP(const LayerT<TConfig>& weights,
float* HWY_RESTRICT b_q =
backward.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
MulByConst(kQueryScale, b_q, kQKVDim);
Rope(b_q, kQKVDim, -pos);
Rope(b_q, kQKVDim, inv_timescale.Const(), -pos);
}
}
@ -342,13 +342,14 @@ static HWY_NOINLINE void CrossEntropyLossGrad(
}
}
template <typename TConfig, template<typename...> typename WeightsT,
template<typename> typename LayerT>
template <typename TConfig, template <typename...> typename WeightsT,
template <typename> typename LayerT>
void CrossEntropyLossBackwardPass(const Prompt& prompt,
const WeightsT<TConfig>& weights,
const ForwardPass<float, TConfig>& forward,
WeightsT<TConfig>& grad,
ForwardPass<float, TConfig>& backward,
RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
static constexpr size_t kVocabSize = TConfig::kVocabSize;
static constexpr size_t kModelDim = TConfig::kModelDim;
@ -398,9 +399,10 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
float* next_layer_grad = layer + 1 < kLayers
? backward.layers[layer + 1].input.data()
: backward.final_layer_output.data();
LayerVJP<TConfig, LayerT>(
*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
num_tokens, *grad.GetLayer(layer), backward.layers[layer], pool);
LayerVJP<TConfig, LayerT>(*weights.GetLayer(layer), forward.layers[layer],
next_layer_grad, num_tokens,
*grad.GetLayer(layer), backward.layers[layer],
inv_timescale, pool);
}
InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens,

View File

@ -29,6 +29,7 @@
#include "hwy/highway.h"
// After highway.h
#include "backprop/backward-inl.h"
#include "gemma/activations.h"
#include "gemma/weights.h"
HWY_BEFORE_NAMESPACE();
@ -41,6 +42,7 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
const ByteStorageT& forward_u8,
ByteStorageT& grad_u8,
ByteStorageT& backward_u8,
RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
using TWeights = CompressedWeights<TConfig>;
const auto& weights = *reinterpret_cast<const TWeights*>(weights_u8.get());
@ -49,25 +51,24 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
const auto& forward = *reinterpret_cast<const TAct*>(forward_u8.get());
auto& backward = *reinterpret_cast<TAct*>(backward_u8.get());
CrossEntropyLossBackwardPass<TConfig, CompressedWeights, CompressedLayer>(
prompt, weights, forward, grad, backward, pool);
prompt, weights, forward, grad, backward, inv_timescale, pool);
}
void CrossEntropyLossBackwardPassT(Model model,
const Prompt& prompt,
void CrossEntropyLossBackwardPassT(Model model, const Prompt& prompt,
const ByteStorageT& weights,
const ByteStorageT& forward,
ByteStorageT& grad,
ByteStorageT& backward,
ByteStorageT& grad, ByteStorageT& backward,
RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
// TODO(janwas): use CallFunctorForModel
switch (model) {
case Model::GEMMA_2B:
CrossEntropyLossBackwardPass<ConfigGemma2B<float>>(
prompt, weights, forward, grad, backward, pool);
prompt, weights, forward, grad, backward, inv_timescale, pool);
break;
case Model::GEMMA_TINY:
CrossEntropyLossBackwardPass<ConfigGemmaTiny<float>>(
prompt, weights, forward, grad, backward, pool);
prompt, weights, forward, grad, backward, inv_timescale, pool);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
@ -83,12 +84,14 @@ namespace gcpp {
HWY_EXPORT(CrossEntropyLossBackwardPassT);
void CrossEntropyLossBackwardPass(
const Model& model, const Prompt& prompt,
const ByteStorageT& weights, const ByteStorageT& forward,
ByteStorageT& grad, ByteStorageT& backward, hwy::ThreadPool& pool) {
void CrossEntropyLossBackwardPass(const Model& model, const Prompt& prompt,
const ByteStorageT& weights,
const ByteStorageT& forward,
ByteStorageT& grad, ByteStorageT& backward,
RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)(
model, prompt, weights, forward, grad, backward, pool);
model, prompt, weights, forward, grad, backward, inv_timescale, pool);
}
} // namespace gcpp

View File

@ -17,15 +17,18 @@
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
#include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/common.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
void CrossEntropyLossBackwardPass(
const Model& model, const Prompt& prompt,
const ByteStorageT& weights, const ByteStorageT& forward,
ByteStorageT& grad, ByteStorageT& backward, hwy::ThreadPool& pool);
void CrossEntropyLossBackwardPass(const Model& model, const Prompt& prompt,
const ByteStorageT& weights,
const ByteStorageT& forward,
ByteStorageT& grad, ByteStorageT& backward,
RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool);
} // namespace gcpp

View File

@ -21,6 +21,7 @@
#include <array>
#include <complex>
#include <cstdlib> // std::abs
#include <random>
#include <vector>
@ -28,9 +29,11 @@
#include "backprop/backward_scalar.h"
#include "backprop/common_scalar.h"
#include "backprop/forward_scalar.h"
#include "backprop/prompt.h"
#include "backprop/sampler.h"
#include "backprop/test_util.h"
#include "compression/weights_raw.h"
#include "gemma/activations.h"
#include "gemma/configs.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -214,6 +217,8 @@ void TestEndToEnd() {
ReverseSequenceSampler training_task({0, 0, 1, 1});
std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
RowVectorBatch<float> inv_timescale =
Activations::CreateInvTimescale<TestConfig>();
for (const Prompt& prompt : batch) {
ReverseSequenceSampler::LogPrompt(prompt);
RandInit(weights.get(), 1.0f, gen);
@ -223,14 +228,14 @@ void TestEndToEnd() {
float loss1 = CrossEntropyLossForwardPass<TestConfig, WeightsF, LayerF>(
prompt.tokens, prompt.context_size, weights.get(), forward1.get(),
pool);
inv_timescale, pool);
EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5);
grad.clear();
CrossEntropyLossBackwardPass<TestConfig, WeightsF, LayerF>(
prompt, weights.get(), forward1.get(), grad.get(), backward.get(),
pool);
inv_timescale, pool);
Complexify(weights.get(), c_weights.get());
auto func = [&]() {

View File

@ -24,6 +24,7 @@
#include <vector>
#include "backprop/activations.h"
#include "gemma/activations.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "hwy/base.h"
@ -88,11 +89,11 @@ static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs,
return loss * scaling;
}
template <typename TConfig, template<typename> typename LayerT>
template <typename TConfig, template <typename> typename LayerT>
void ApplyForwardLayer(const LayerT<TConfig>& weights,
ForwardLayer<float, TConfig>& activations,
size_t num_tokens,
float* HWY_RESTRICT output,
size_t num_tokens, float* HWY_RESTRICT output,
const RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kSeqLen = TConfig::kSeqLen;
@ -117,14 +118,14 @@ void ApplyForwardLayer(const LayerT<TConfig>& weights,
for (size_t pos = 0; pos < num_tokens; ++pos) {
float* HWY_RESTRICT k =
activations.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
Rope(k, kQKVDim, pos);
Rope(k, kQKVDim, inv_timescale.Const(), pos);
}
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
float* HWY_RESTRICT q =
activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
Rope(q, kQKVDim, pos);
Rope(q, kQKVDim, inv_timescale.Const(), pos);
MulByConst(kQueryScale, q, kQKVDim);
});
@ -222,12 +223,13 @@ void ApplyForwardLayer(const LayerT<TConfig>& weights,
}
}
template <typename TConfig, template<typename...> typename WeightsT,
template<typename> typename LayerT>
template <typename TConfig, template <typename...> typename WeightsT,
template <typename> typename LayerT>
float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
size_t context_size,
const WeightsT<TConfig>& weights,
ForwardPass<float, TConfig>& forward,
const RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
static constexpr size_t kVocabSize = TConfig::kVocabSize;
static constexpr size_t kModelDim = TConfig::kModelDim;
@ -251,9 +253,9 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
float* HWY_RESTRICT output = layer + 1 < kLayers ?
forward.layers[layer + 1].input.data() :
forward.final_layer_output.data();
ApplyForwardLayer<TConfig, LayerT>(
*weights.GetLayer(layer), forward.layers[layer],
num_tokens, output, pool);
ApplyForwardLayer<TConfig, LayerT>(*weights.GetLayer(layer),
forward.layers[layer], num_tokens,
output, inv_timescale, pool);
}
ApplyRMSNorm(weights.final_norm_scale.data(),

View File

@ -17,6 +17,7 @@
#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/common.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -39,28 +40,31 @@ template <typename TConfig>
float CrossEntropyLossForwardPass(const Prompt& prompt,
const ByteStorageT& weights_u8,
ByteStorageT& forward_u8,
RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
const auto& weights =
*reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
auto& forward =
*reinterpret_cast<ForwardPass<float, TConfig>*>(forward_u8.get());
return
CrossEntropyLossForwardPass<TConfig, CompressedWeights, CompressedLayer>(
prompt.tokens, prompt.context_size, weights, forward, pool);
return CrossEntropyLossForwardPass<TConfig, CompressedWeights,
CompressedLayer>(
prompt.tokens, prompt.context_size, weights, forward, inv_timescale,
pool);
}
float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt,
const ByteStorageT& weights,
ByteStorageT& forward,
RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
// TODO(janwas): use CallFunctorForModel
switch (model) {
case Model::GEMMA_2B:
return CrossEntropyLossForwardPass<ConfigGemma2B<float>>(prompt, weights,
forward, pool);
return CrossEntropyLossForwardPass<ConfigGemma2B<float>>(
prompt, weights, forward, inv_timescale, pool);
case Model::GEMMA_TINY:
return CrossEntropyLossForwardPass<ConfigGemmaTiny<float>>(
prompt, weights, forward, pool);
prompt, weights, forward, inv_timescale, pool);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
@ -75,11 +79,13 @@ namespace gcpp {
HWY_EXPORT(CrossEntropyLossForwardPassT);
float CrossEntropyLossForwardPass(
const Model& model, const Prompt& prompt, const ByteStorageT& weights,
ByteStorageT& forward, hwy::ThreadPool& pool) {
float CrossEntropyLossForwardPass(const Model& model, const Prompt& prompt,
const ByteStorageT& weights,
ByteStorageT& forward,
RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)(
model, prompt, weights, forward, pool);
model, prompt, weights, forward, inv_timescale, pool);
}
} // namespace gcpp

View File

@ -17,14 +17,17 @@
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
#include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/common.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
float CrossEntropyLossForwardPass(
const Model& model, const Prompt& prompt, const ByteStorageT& weights,
ByteStorageT& forward, hwy::ThreadPool& pool);
float CrossEntropyLossForwardPass(const Model& model, const Prompt& prompt,
const ByteStorageT& weights,
ByteStorageT& forward,
RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool);
} // namespace gcpp

View File

@ -26,6 +26,7 @@
#include "backprop/optimizer.h"
#include "backprop/prompt.h"
#include "backprop/sampler.h"
#include "gemma/activations.h"
#include "gemma/common.h"
#include "gemma/gemma.h"
#include "gemma/weights.h"
@ -56,6 +57,9 @@ TEST(OptimizeTest, GradientDescent) {
CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight);
KVCache kv_cache = KVCache::Create(info.model, /*prefill_tbatch_size=*/16);
RowVectorBatch<float> inv_timescale =
Activations::CreateInvTimescale<ConfigGemmaTiny<float>>();
Gemma gemma(GemmaTokenizer(), info, pools);
const auto generate = [&](const std::vector<int>& prompt) {
@ -118,10 +122,10 @@ TEST(OptimizeTest, GradientDescent) {
num_ok = 0;
for (size_t i = 0; i < kBatchSize; ++i) {
Prompt prompt = training_task.Sample(sgen);
total_loss += CrossEntropyLossForwardPass(info.model, prompt,
gemma.Weights(), forward, pool);
total_loss += CrossEntropyLossForwardPass(
info.model, prompt, gemma.Weights(), forward, inv_timescale, pool);
CrossEntropyLossBackwardPass(info.model, prompt, gemma.Weights(), forward,
grad, backward, pool);
grad, backward, inv_timescale, pool);
num_ok += verify(prompt) ? 1 : 0;
}
total_loss /= kBatchSize;

View File

@ -18,6 +18,8 @@
#include <stddef.h>
#include <cmath>
#include "gemma/common.h" // kMaxThreads - TODO: remove
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // HWY_DASSERT
@ -54,6 +56,7 @@ class RowVectorBatch {
// For MatMul or other operations that process the entire batch at once.
T* All() { return mem_.get(); }
const T* Const() const { return mem_.get(); }
size_t NumBytes() const { return batch_size_ * len_ * sizeof(T); }
private:
@ -88,6 +91,9 @@ struct Activations {
RowVectorBatch<float> griffin_gate_x;
RowVectorBatch<float> griffin_multiplier;
// Rope
RowVectorBatch<float> inv_timescale;
// For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into
// per-thread storage.
// TODO: remove once MatVec is gone.
@ -106,6 +112,21 @@ struct Activations {
return TConfig::kQKVDim * (IsMHA<TConfig>() ? 3 : 1);
}
template <class TConfig>
static RowVectorBatch<float> CreateInvTimescale() {
constexpr size_t kQKVDim = TConfig::kQKVDim;
const size_t rope_dim = TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim;
RowVectorBatch<float> inv_timescale(1, rope_dim / 2);
for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
const float freq_exponents =
static_cast<float>(2 * dim) / static_cast<float>(rope_dim);
// Replacing with expf(ln(1E4) * freq_exponents) changes results
// noticeably.
inv_timescale.Batch(0)[dim] = 1.0f / std::pow(10000.0f, freq_exponents);
}
return inv_timescale;
}
template <class TConfig>
void Allocate(size_t batch_size) {
constexpr size_t kModelDim = TConfig::kModelDim;
@ -138,6 +159,8 @@ struct Activations {
griffin_multiplier = RowVectorBatch<float>(batch_size, kModelDim);
}
inv_timescale = CreateInvTimescale<TConfig>();
even_odd = RowVectorBatch<float>(1, kModelDim * kMaxThreads);
}
};

View File

@ -196,216 +196,317 @@ HWY_NOINLINE void GriffinRecurrent(
}
}
template <class TConfig, typename T>
HWY_NOINLINE void PostQK(T* HWY_RESTRICT inout, size_t pos, size_t layer) {
constexpr size_t kQKVDim = TConfig::kQKVDim;
// PostQKType::Rope
Rope(inout, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
}
// Wrapper class; holds arguments in member variables to shorten call sites.
template <class TConfig>
HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
size_t num_queries, size_t layer,
Activations& activations,
const CompressedLayer<TConfig>* layer_weights,
const KVCaches& kv_caches,
hwy::ThreadPool& pool) {
PROFILER_ZONE("Gen.Attention");
HWY_DASSERT(interleaved_start % num_queries == 0);
constexpr size_t kQKVDim = TConfig::kQKVDim;
constexpr size_t kQStride = Activations::QStride<TConfig>();
constexpr size_t kCachePosSize = CachePosSize<TConfig>()();
constexpr size_t kCacheLayerSize = CacheLayerSize<TConfig>()();
constexpr size_t kModelDim = TConfig::kModelDim;
constexpr size_t kHeads = TConfig::kHeads;
constexpr size_t kKVHeads = TConfig::kKVHeads;
constexpr size_t kSeqLen = TConfig::kSeqLen;
GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale<TConfig>();
class GemmaAttention {
static constexpr size_t kCacheLayerSize = CacheLayerSize<TConfig>()();
static constexpr size_t kCachePosSize = CachePosSize<TConfig>()();
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kQStride = Activations::QStride<TConfig>();
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr bool kIsMHA = Activations::IsMHA<TConfig>();
HWY_ASSERT(num_queries <= kv_caches.size());
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
// The attention window usually starts at 0 unless unless `pos` is larger than
// the attention window size, then it is `pos` - window_size + 1.
static HWY_INLINE size_t StartPos(size_t pos, size_t layer) {
const size_t att_window_size = TConfig::kAttentionWindowSizes[layer];
return pos - std::min(att_window_size - 1, pos);
}
// Multi-Head Attention a.k.a. "use_qkv_einsum".
constexpr bool kIsMHA = Activations::IsMHA<TConfig>();
static_assert(!kIsMHA || TConfig::kInterleaveQKV); // MHA => interleaved
const size_t batch_start = interleaved_start / num_queries;
const size_t num_interleaved = num_tokens * num_queries;
// For the computation of Q, K, and V, it is useful to remember that
// qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kKQVDim, kModelDim]
// and kQStride = kQKVDim * (kIsMHA ? 3 : 1);
//
// Compute Q only or QKV (if MHA).
// If MHA, this also computes KV, which we copy to the KV cache below.
MatMul_4x4</*kAdd=*/false>(
num_interleaved, MakeMat(activations.pre_att_rms_out.All(), kModelDim),
MakeMat(layer_weights->qkv_einsum_w.data(), kModelDim),
layer_weights->qkv_einsum_w.scale(), /*add=*/nullptr,
MakeMat(activations.q.All(), kHeads * kQStride), pool);
// Compute KV if not MHA.
if constexpr (!kIsMHA) {
// Single query and no wraparound means we can use a matmul and write
// directly into the KV cache with a stride of kCachePosSize.
if (num_queries == 1 &&
batch_start + num_tokens <= div_seq_len.GetDivisor()) {
const size_t kv_ofs =
batch_start * kCachePosSize + layer * kCacheLayerSize;
// KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
float* HWY_RESTRICT kv = kv_caches[0].kv_cache.get() + kv_ofs;
MatMul_4x4</*kAdd=*/false>(
num_tokens, MakeMat(activations.pre_att_rms_out.All(), kModelDim),
MakeMat(layer_weights->qkv_einsum_w.data(), kModelDim, kModelDim,
kHeads * kQKVDim * kModelDim),
layer_weights->qkv_einsum_w.scale(), /*add=*/nullptr,
MakeMat(kv, kKVHeads * 2 * kQKVDim, kCachePosSize), pool);
template <typename T>
HWY_INLINE void PositionalEncodingQK(const T* qk, size_t pos, size_t layer,
const float mul, T* qk_out) {
const float* inv_timescale = activations_.inv_timescale.Const();
// PostQKType::Rope
(void)layer;
if (TConfig::kUseHalfRope) {
hwy::CopyBytes(qk, qk_out, kQKVDim * sizeof(*qk));
Rope(qk_out, kQKVDim / 2, inv_timescale, pos);
MulByConst(mul, qk_out, kQKVDim);
} else {
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const float* x = activations.pre_att_rms_out.Batch(interleaved_idx);
const size_t query_idx = interleaved_idx % num_queries;
const size_t batch_idx = interleaved_idx / num_queries;
KVCache& kv_cache = kv_caches[query_idx];
const size_t cache_pos = div_seq_len.Remainder(batch_start + batch_idx);
const size_t kv_offset =
cache_pos * kCachePosSize + layer * kCacheLayerSize;
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
RopeAndMulBy(mul, qk, kQKVDim, inv_timescale, pos, qk_out);
}
}
// Fills activations.q and computes KV. For kIsMHA, a single MatMul suffices
// and we later copy KV from q to KVCache. Otherwise, a second MatMul writes
// KV directly to KVCache.
HWY_NOINLINE void ComputeQKV(const size_t batch_start,
const size_t num_interleaved) {
PROFILER_ZONE("Gen.Attention.QKV");
// For the computation of Q, K, and V, it is useful to remember that
// qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kKQVDim, kModelDim]
// and kQStride = kQKVDim * (kIsMHA ? 3 : 1);
const auto pre_att_rms_out =
MakeMat(activations_.pre_att_rms_out.All(), kModelDim);
MatMul_4x4</*kAdd=*/false>(
num_interleaved, pre_att_rms_out,
MakeMat(layer_weights_.qkv_einsum_w.data(), kModelDim),
layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr,
MakeMat(activations_.q.All(), kHeads * kQStride), pool_);
if constexpr (kIsMHA) {
static_assert(TConfig::kInterleaveQKV, "MHA implies interleaved");
// Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
} else {
// Single query and no wraparound means we can use a matmul and write
// directly into the KV cache with a stride of kCachePosSize.
if (num_queries_ == 1 &&
batch_start + num_tokens_ <= div_seq_len_.GetDivisor()) {
const size_t kv_ofs =
batch_start * kCachePosSize + layer_ * kCacheLayerSize;
// KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
MatVec<kKVHeads * 2 * kQKVDim, kModelDim>(
layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
activations.even_odd.All(), kv, pool);
float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
MatMul_4x4</*kAdd=*/false>(
num_tokens_, pre_att_rms_out,
MakeMat(layer_weights_.qkv_einsum_w.data(), kModelDim, kModelDim,
kHeads * kQKVDim * kModelDim),
layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr,
MakeMat(kv, kKVHeads * 2 * kQKVDim, kCachePosSize), pool_);
} else {
// Proceed row by row because there will be wraparound.
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const float* x = activations_.pre_att_rms_out.Batch(interleaved_idx);
const size_t query_idx = interleaved_idx % num_queries_;
const size_t batch_idx = interleaved_idx / num_queries_;
KVCache& kv_cache = kv_caches_[query_idx];
const size_t cache_pos =
div_seq_len_.Remainder(batch_start + batch_idx);
const size_t kv_offset =
cache_pos * kCachePosSize + layer_ * kCacheLayerSize;
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
// KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
MatVec<kKVHeads * 2 * kQKVDim, kModelDim>(
layer_weights_.qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
activations_.even_odd.All(), kv, pool_);
}
}
}
// Apply positional encodings for K (and copy KV to cache if MHA).
pool_.Run(
0, kKVHeads * num_interleaved,
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t head = task % kKVHeads;
const size_t interleaved_idx = task / kKVHeads;
const size_t query_idx = interleaved_idx % num_queries_;
const size_t batch_idx = interleaved_idx / num_queries_;
const size_t pos = batch_start + batch_idx;
const size_t cache_pos = div_seq_len_.Remainder(pos);
const size_t kv_offset = cache_pos * kCachePosSize +
layer_ * kCacheLayerSize +
head * kQKVDim * 2;
KVCache& kv_cache = kv_caches_[query_idx];
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
const float* HWY_RESTRICT mha_kv =
activations_.q.Batch(interleaved_idx) + head * kQStride + kQKVDim;
// Copy from `q` if MHA, or apply in-place.
PositionalEncodingQK(kIsMHA ? mha_kv : kv, pos, layer_, 1.0f, kv);
// If MHA, also copy V into KVCache.
if (kIsMHA) {
hwy::CopyBytes(mha_kv + kQKVDim, kv + kQKVDim,
kQKVDim * sizeof(*kv));
}
});
}
// Computes Q.K scores, which are "logits" (or scores) stored to head_att.
HWY_INLINE void QDotK(const size_t start_pos, const size_t pos,
const size_t head_offset, const float* HWY_RESTRICT q,
const KVCache& kv_cache, float* HWY_RESTRICT head_att) {
if (HWY_LIKELY(pos <= kSeqLen)) {
// Slightly faster: no wraparound.
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
const size_t kv_offset =
pos2 * kCachePosSize + layer_ * kCacheLayerSize + head_offset;
const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
const float score = Dot(q, k, kQKVDim);
head_att[pos2] = score;
}
} else {
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
const size_t cache_pos = div_seq_len_.Remainder(pos2);
const size_t kv_offset =
cache_pos * kCachePosSize + layer_ * kCacheLayerSize + head_offset;
const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
const float score = Dot(q, k, kQKVDim);
head_att[pos2 % kSeqLen] = score;
}
}
}
// Apply positional encodings for K (and copy KV to cache if MHA).
pool.Run(
0, kKVHeads * num_interleaved,
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t head = task % kKVHeads;
const size_t interleaved_idx = task / kKVHeads;
const size_t query_idx = interleaved_idx % num_queries;
const size_t batch_idx = interleaved_idx / num_queries;
const size_t pos = batch_start + batch_idx;
const size_t cache_pos = div_seq_len.Remainder(pos);
const size_t kv_offset = cache_pos * kCachePosSize +
layer * kCacheLayerSize + head * kQKVDim * 2;
KVCache& kv_cache = kv_caches[query_idx];
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
if constexpr (kIsMHA) {
// For MHA, copy KV into the KV cache from scratch space (see above).
const float* HWY_RESTRICT q =
activations.q.Batch(interleaved_idx) + head * kQStride;
// Skip past the Q part of `q`, and copy KV to `kv`.
hwy::CopyBytes(q + kQKVDim, kv, 2 * kQKVDim * sizeof(float));
}
PostQK<TConfig>(kv, pos, layer);
});
// Accumulates the sum of v (from `kv_cache`) * probability (`head_att`) into
// `att_out`. Equivalent in gemma/modules.py:
// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
static HWY_INLINE void WeightedSumV(const size_t start_pos, const size_t pos,
const float* HWY_RESTRICT head_att,
const size_t layer,
const size_t head_offset,
const hwy::Divisor& div_seq_len,
const KVCache& kv_cache,
float* HWY_RESTRICT att_out) {
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
// A "head group" in the context of GQA refers to a collection of query heads
// that share the same key and value heads.
static_assert((kHeads % kKVHeads) == 0,
"query heads must be a multiple of key-value heads");
constexpr size_t kHeadGroups = kHeads / kKVHeads;
// For each head (token, query), compute Q.K, softmax, and weighted V.
pool.Run(
0, kHeads * num_interleaved,
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t head = task % kHeads;
const size_t interleaved_idx = task / kHeads;
const size_t query_idx = interleaved_idx % num_queries;
const size_t batch_idx = interleaved_idx / num_queries;
const size_t head_offset = (head / kHeadGroups) * kQKVDim * 2;
KVCache& kv_cache = kv_caches[query_idx];
float* HWY_RESTRICT q =
activations.q.Batch(interleaved_idx) + head * kQStride;
// Apply rope and scaling to Q.
const size_t pos = batch_start + batch_idx;
PostQK<TConfig>(q, pos, layer);
MulByConst(kQueryScale, q, kQKVDim);
// Compute Q.K scores, yielding "logits" (or scores) in head_att.
float* HWY_RESTRICT head_att =
activations.att.Batch(interleaved_idx) + head * kSeqLen;
// Usually start_pos is 0, unless pos is larger than the attention
// window size, then it is pos - window_size + 1.
const size_t start_pos =
pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
const size_t cache_pos = div_seq_len.Remainder(pos2);
const size_t kv_offset =
cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
const float score = Dot(q, k, kQKVDim);
head_att[pos2 % kSeqLen] = score;
}
// SoftMax. May be preceded by SoftCap. Yields "probabilities" in
// head_att.
const size_t head_att_len = std::min(pos + 1, kSeqLen);
if constexpr (TConfig::kAttCap > 0.0f) {
LogitsSoftCap(TConfig::kAttCap, head_att, head_att_len);
}
Softmax(head_att, head_att_len);
// Summation of v (kv_cache) weighted by probs (head_att)
// into "encoded" (att_out). Compare gemma/modules.py:
// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
float* HWY_RESTRICT att_out =
activations.att_out.Batch(interleaved_idx) + head * kQKVDim;
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
const size_t cache_pos = div_seq_len.Remainder(pos2);
const size_t kv_offset =
cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
float* HWY_RESTRICT v =
kv_cache.kv_cache.get() + kv_offset + kQKVDim;
MulByConstAndAdd(head_att[pos2 % kSeqLen], v, att_out, kQKVDim);
}
});
// Sum encoded (att_out) over num_heads and head_dim (kQKVDim)
// into output (layer_out). Compare gemma/modules.py:
// attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
// TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
// rearranging the weights.
float* HWY_RESTRICT att_out = activations.att_out.Batch(interleaved_idx);
float* HWY_RESTRICT layer_out =
activations.att_post2.Batch(interleaved_idx);
// Head 0 (and potentially biases) -> layer_out.
// attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
constexpr bool kAdd = TConfig::kSoftmaxAttnOutputBiases;
const float* bias =
kAdd ? layer_weights->attention_output_biases.data_scale1() : nullptr;
MatVecT<kAdd, kModelDim, kQKVDim>(
layer_weights->attn_vec_einsum_w, 0, att_out, bias,
activations.even_odd.All(), layer_out, pool);
// Head 1 and following are added to layer_out.
for (size_t head = 1; head < kHeads; ++head) {
// NOTE: this is a single kModelDim temp output. If parallelized or using
// MatMul, add per-thread storage.
float* HWY_RESTRICT head_out = activations.att_post1.All();
// TODO: requires MatMul support for offsets.
MatVec<kModelDim, kQKVDim>(
layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim,
att_out + head * kQKVDim, activations.even_odd.All(), head_out, pool);
AddFrom(head_out, layer_out, kModelDim);
if (HWY_LIKELY(pos <= kSeqLen)) {
// Slightly faster: no wraparound.
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
const size_t kv_offset =
pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
const float* HWY_RESTRICT v =
kv_cache.kv_cache.get() + kv_offset + kQKVDim;
MulByConstAndAdd(head_att[pos2], v, att_out, kQKVDim);
}
} else {
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
const size_t cache_pos = div_seq_len.Remainder(pos2);
const size_t kv_offset =
cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
const float* HWY_RESTRICT v =
kv_cache.kv_cache.get() + kv_offset + kQKVDim;
MulByConstAndAdd(head_att[pos2 % kSeqLen], v, att_out, kQKVDim);
}
}
}
}
HWY_NOINLINE void DotSoftmaxWeightedSum(const size_t batch_start,
const size_t num_interleaved) {
PROFILER_ZONE("Gen.Attention.DotSoftmax");
GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale<TConfig>();
// A "head group" in the context of GQA refers to a collection of query
// heads that share the same key and value heads.
static_assert((kHeads % kKVHeads) == 0,
"query heads must be a multiple of key-value heads");
constexpr size_t kHeadGroups = kHeads / kKVHeads;
// For each head (token, query), compute Q.K, softmax, and weighted V.
pool_.Run(0, kHeads * num_interleaved,
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t head = task % kHeads;
const size_t interleaved_idx = task / kHeads;
const size_t query_idx = interleaved_idx % num_queries_;
const size_t batch_idx = interleaved_idx / num_queries_;
const size_t head_offset = (head / kHeadGroups) * kQKVDim * 2;
KVCache& kv_cache = kv_caches_[query_idx];
float* HWY_RESTRICT q =
activations_.q.Batch(interleaved_idx) + head * kQStride;
// Apply rope and scaling to Q.
const size_t pos = batch_start + batch_idx;
PositionalEncodingQK(q, pos, layer_, kQueryScale, q);
const size_t start_pos = StartPos(pos, layer_);
float* HWY_RESTRICT head_att =
activations_.att.Batch(interleaved_idx) + head * kSeqLen;
QDotK(start_pos, pos, head_offset, q, kv_cache, head_att);
// SoftMax with optional SoftCap yields "probabilities" in
// head_att.
const size_t head_att_len = std::min(pos + 1, kSeqLen);
MaybeLogitsSoftCap(TConfig::kAttCap, head_att, head_att_len);
Softmax(head_att, head_att_len);
float* HWY_RESTRICT att_out =
activations_.att_out.Batch(interleaved_idx) +
head * kQKVDim;
WeightedSumV(start_pos, pos, head_att, layer_, head_offset,
div_seq_len_, kv_cache, att_out);
});
}
// Sums encoded (`att_out`) over num_heads and head_dim (kQKVDim) into output
// (`layer_out`). Compare gemma/modules.py:
// attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
HWY_NOINLINE void SumHeads(const size_t num_interleaved) {
PROFILER_ZONE("Gen.Attention.SumHeads");
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
// TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
// rearranging the weights.
float* HWY_RESTRICT att_out = activations_.att_out.Batch(interleaved_idx);
float* HWY_RESTRICT layer_out =
activations_.att_post2.Batch(interleaved_idx);
// Head 0 (and potentially biases) -> layer_out.
// attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
constexpr bool kAdd = TConfig::kSoftmaxAttnOutputBiases;
const float* bias =
kAdd ? layer_weights_.attention_output_biases.data_scale1() : nullptr;
MatVecT<kAdd, kModelDim, kQKVDim>(
layer_weights_.attn_vec_einsum_w, 0, att_out, bias,
activations_.even_odd.All(), layer_out, pool_);
// Head 1 and following are added to layer_out.
for (size_t head = 1; head < kHeads; ++head) {
// NOTE: this is a single kModelDim temp output. If parallelized or
// using MatMul, add per-thread storage.
float* HWY_RESTRICT head_out = activations_.att_post1.All();
// TODO: requires MatMul support for offsets.
MatVec<kModelDim, kQKVDim>(
layer_weights_.attn_vec_einsum_w, head * kModelDim * kQKVDim,
att_out + head * kQKVDim, activations_.even_odd.All(), head_out,
pool_);
AddFrom(head_out, layer_out, kModelDim);
}
}
}
public:
GemmaAttention(size_t interleaved_start, size_t num_tokens,
size_t num_queries, size_t layer, Activations& activations,
const CompressedLayer<TConfig>* layer_weights,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
hwy::ThreadPool& pool)
: interleaved_start_(interleaved_start),
num_tokens_(num_tokens),
num_queries_(num_queries),
layer_(layer),
activations_(activations),
layer_weights_(*layer_weights),
div_seq_len_(div_seq_len),
kv_caches_(kv_caches),
pool_(pool) {
HWY_DASSERT(interleaved_start_ % num_queries_ == 0);
HWY_DASSERT(num_queries_ <= kv_caches_.size());
}
HWY_INLINE void operator()() {
const size_t batch_start = interleaved_start_ / num_queries_;
const size_t num_interleaved = num_tokens_ * num_queries_;
ComputeQKV(batch_start, num_interleaved);
DotSoftmaxWeightedSum(batch_start, num_interleaved);
SumHeads(num_interleaved);
}
private:
const size_t interleaved_start_;
const size_t num_tokens_;
const size_t num_queries_;
const size_t layer_;
Activations& activations_;
const CompressedLayer<TConfig>& layer_weights_;
const hwy::Divisor& div_seq_len_;
const KVCaches& kv_caches_;
hwy::ThreadPool& pool_;
};
template <class TConfig>
HWY_NOINLINE void Attention(LayerAttentionType type, size_t interleaved_start,
size_t num_tokens, size_t num_queries, size_t layer,
Activations& activations,
const CompressedLayer<TConfig>* layer_weights,
const hwy::Divisor& div_seq_len,
const KVCaches& kv_caches, hwy::ThreadPool& pool) {
if (type == LayerAttentionType::kGemma) {
GemmaAttention<TConfig>(interleaved_start, num_tokens, num_queries, layer,
activations, layer_weights, kv_caches, pool);
activations, layer_weights, div_seq_len, kv_caches,
pool)();
} else {
// Only reached if the model is Griffin. `if constexpr` prevents generating
// this code for non-Griffin models.
@ -421,6 +522,7 @@ HWY_NOINLINE void Attention(LayerAttentionType type, size_t interleaved_start,
template <class TConfig, typename T>
HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2,
size_t count) {
PROFILER_ZONE("Gen.Activation");
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<T>;
using VF = hn::Vec<DF>;
@ -516,7 +618,8 @@ template <class TConfig>
HWY_NOINLINE void TransformerLayer(
size_t num_tokens, size_t num_queries, size_t pos, size_t layer,
const CompressedLayer<TConfig>* layer_weights, Activations& activations,
const KVCaches& kv_caches, hwy::ThreadPool& pool) {
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
hwy::ThreadPool& pool) {
constexpr size_t kModelDim = TConfig::kModelDim;
const size_t num_interleaved = num_tokens * num_queries;
auto type = TConfig::kLayerConfig[layer];
@ -528,7 +631,7 @@ HWY_NOINLINE void TransformerLayer(
activations.pre_att_rms_out.All(), kModelDim);
Attention<TConfig>(type, pos, num_tokens, num_queries, layer_of_type,
activations, layer_weights, kv_caches, pool);
activations, layer_weights, div_seq_len, kv_caches, pool);
PostNorm<TConfig>(num_interleaved, layer_weights->post_attention_norm_scale,
activations.att_post2.All());
@ -606,6 +709,7 @@ class PrefillState {
const size_t query_idx_start,
const CompressedWeights<TConfig>& weights,
const RuntimeConfig& runtime_config,
const hwy::Divisor& div_seq_len,
const KVCaches& kv_caches, PerClusterPools& pools) {
PROFILER_ZONE("Gen.Prefill");
const size_t num_queries = prompts.size();
@ -638,9 +742,10 @@ class PrefillState {
// Transformer with one batch of tokens from a single query.
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
const auto* layer_weights = weights.GetLayer(layer);
TransformerLayer<TConfig>(
tbatch_size, kPrefillQueries, pos + tbatch_start, layer,
layer_weights, activations, prefill_kv_caches, inner_pool);
TransformerLayer<TConfig>(tbatch_size, kPrefillQueries,
pos + tbatch_start, layer,
layer_weights, activations, div_seq_len,
prefill_kv_caches, inner_pool);
}
// NOTE: we unconditionally call StreamToken, even if EOS.
@ -664,6 +769,7 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
size_t num_queries, size_t pos,
const CompressedWeights<TConfig>& weights,
Activations& activations,
const hwy::Divisor& div_seq_len,
const KVCaches& kv_caches, hwy::ThreadPool& pool,
const LayersOutputFunc& layers_output) {
const size_t num_interleaved = num_tokens * num_queries;
@ -684,7 +790,8 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
const CompressedLayer<TConfig>* layer_weights = weights.GetLayer(layer);
TransformerLayer<TConfig>(num_tokens, num_queries, pos, layer,
layer_weights, activations, kv_caches, pool);
layer_weights, activations, div_seq_len,
kv_caches, pool);
if (layers_output) {
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
@ -822,6 +929,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096.
HWY_ASSERT(num_queries <= activations.x.BatchSize());
HWY_ASSERT(kv_caches.size() == num_queries);
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
size_t min_prompt_size, max_prompt_size;
const std::vector<int> prompt = InterleaveQueries(
@ -857,7 +965,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
pools);
prefill_start = hwy::platform::Now();
prefill.Prefill<TConfig>(prompts, prefill_per_query, pos, query_idx_start,
weights, runtime_config, kv_caches, pools);
weights, runtime_config, div_seq_len, kv_caches,
pools);
timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start);
}
@ -881,8 +990,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
++gen_per_query) {
// Decode: generate one token for each query.
Transformer<TConfig>(gen_tokens.data(), /*num_tokens=*/1, num_queries,
interleaved_pos, weights, activations, kv_caches, pool,
runtime_config.layers_output);
interleaved_pos, weights, activations, div_seq_len,
kv_caches, pool, runtime_config.layers_output);
interleaved_pos += num_queries;
bool all_queries_eos = true;
@ -895,9 +1004,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
MakeMat(activations.logits.All(), kVocabSize), pool);
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
if constexpr (TConfig::kFinalCap > 0.0f) {
LogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize);
}
MaybeLogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize);
Softmax(logits, kVocabSize);
const int token = sample_token(logits, kVocabSize);
timing_info.NotifyGenerated(prefill_start, gen_start);

View File

@ -381,16 +381,16 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
of this rotation matrix which is simply the same matrix with -pos parameter)
*/
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
size_t dim_qkv, int pos) {
// `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
// This overload is called from backprop/ and if kUseHalfRope.
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
float* HWY_RESTRICT x, size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, int pos) {
PROFILER_FUNC;
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
const float freq_exponents =
StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
const float timescale = powf(10000.0f, freq_exponents);
const float theta = StaticCast<float>(pos) / timescale;
const float theta = StaticCast<float>(pos) * inv_timescale[dim];
const float cos_val = cosf(theta);
const float sin_val = sinf(theta);
const float x0 = x[dim];
@ -400,24 +400,23 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
}
}
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul,
float* HWY_RESTRICT x,
size_t dim_qkv,
int pos) {
// TODO(janwas): vectorize
// `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
const float mul, const float* HWY_RESTRICT x, size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, int pos,
float* HWY_RESTRICT x_out) {
PROFILER_FUNC;
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
const float freq_exponents =
StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
const float timescale = powf(10000.0f, freq_exponents);
const float theta = StaticCast<float>(pos) / timescale;
const float theta = StaticCast<float>(pos) * inv_timescale[dim];
const float cos_val = cosf(theta);
const float sin_val = sinf(theta);
const float x0 = x[dim];
const float x1 = x[dim + half_dim_qkv];
x[dim] = mul * (x0 * cos_val - x1 * sin_val);
x[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val);
x_out[dim] = mul * (x0 * cos_val - x1 * sin_val);
x_out[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val);
}
}
@ -577,12 +576,19 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
});
}
static HWY_INLINE HWY_MAYBE_UNUSED void LogitsSoftCap(const float cap,
float* HWY_RESTRICT x,
const size_t size) {
static HWY_INLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
const size_t size) {
LogitsSoftCap(cap, x, size, size);
}
// Calls LogitsSoftCap if cap != 0.0f.
static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
const float cap, float* HWY_RESTRICT x, const size_t size) {
if (cap != 0.0f) {
LogitsSoftCap(cap, x, size, size);
}
}
static HWY_NOINLINE HWY_MAYBE_UNUSED size_t
SampleArgmax(const float* probabilities, size_t vocab_size) {
size_t max_index = 0;