mirror of https://github.com/google/gemma.cpp.git
1.03-1.08x decode speedup: precompute Rope theta, fuse
Split attention into functions, move into class. Fuse Rope and MulBy, allow non-in-place version to avoid copy from q to KV. Sink if() into MaybeLogitsSoftCap. PiperOrigin-RevId: 661168418
This commit is contained in:
parent
27258b03e6
commit
2ebbe4076f
|
|
@ -455,6 +455,7 @@ cc_test(
|
||||||
":common",
|
":common",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
":ops",
|
":ops",
|
||||||
|
":prompt",
|
||||||
":sampler",
|
":sampler",
|
||||||
"@googletest//:gtest_main",
|
"@googletest//:gtest_main",
|
||||||
"//compression:weights_raw",
|
"//compression:weights_raw",
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@
|
||||||
|
|
||||||
#include "backprop/activations.h"
|
#include "backprop/activations.h"
|
||||||
#include "backprop/prompt.h"
|
#include "backprop/prompt.h"
|
||||||
|
#include "gemma/activations.h" // CreateInvTimescale
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
@ -166,13 +167,12 @@ static HWY_NOINLINE void InputEmbeddingVJP(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename TConfig, template<typename> typename LayerT>
|
template <typename TConfig, template <typename> typename LayerT>
|
||||||
void LayerVJP(const LayerT<TConfig>& weights,
|
void LayerVJP(const LayerT<TConfig>& weights,
|
||||||
const ForwardLayer<float, TConfig>& forward,
|
const ForwardLayer<float, TConfig>& forward,
|
||||||
const float* HWY_RESTRICT next_layer_grad,
|
const float* HWY_RESTRICT next_layer_grad, size_t num_tokens,
|
||||||
size_t num_tokens,
|
LayerT<TConfig>& grad, ForwardLayer<float, TConfig>& backward,
|
||||||
LayerT<TConfig>& grad,
|
const RowVectorBatch<float>& inv_timescale,
|
||||||
ForwardLayer<float, TConfig>& backward,
|
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
static constexpr size_t kQKVDim = TConfig::kQKVDim;
|
static constexpr size_t kQKVDim = TConfig::kQKVDim;
|
||||||
|
|
@ -279,7 +279,7 @@ void LayerVJP(const LayerT<TConfig>& weights,
|
||||||
for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
|
for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
|
||||||
float* HWY_RESTRICT b_kv =
|
float* HWY_RESTRICT b_kv =
|
||||||
backward.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
|
backward.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
|
||||||
Rope(b_kv, kQKVDim, -pos);
|
Rope(b_kv, kQKVDim, inv_timescale.Const(), -pos);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t head = 0; head < kHeads; ++head) {
|
for (size_t head = 0; head < kHeads; ++head) {
|
||||||
|
|
@ -287,7 +287,7 @@ void LayerVJP(const LayerT<TConfig>& weights,
|
||||||
float* HWY_RESTRICT b_q =
|
float* HWY_RESTRICT b_q =
|
||||||
backward.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
|
backward.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
|
||||||
MulByConst(kQueryScale, b_q, kQKVDim);
|
MulByConst(kQueryScale, b_q, kQKVDim);
|
||||||
Rope(b_q, kQKVDim, -pos);
|
Rope(b_q, kQKVDim, inv_timescale.Const(), -pos);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -342,13 +342,14 @@ static HWY_NOINLINE void CrossEntropyLossGrad(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename TConfig, template<typename...> typename WeightsT,
|
template <typename TConfig, template <typename...> typename WeightsT,
|
||||||
template<typename> typename LayerT>
|
template <typename> typename LayerT>
|
||||||
void CrossEntropyLossBackwardPass(const Prompt& prompt,
|
void CrossEntropyLossBackwardPass(const Prompt& prompt,
|
||||||
const WeightsT<TConfig>& weights,
|
const WeightsT<TConfig>& weights,
|
||||||
const ForwardPass<float, TConfig>& forward,
|
const ForwardPass<float, TConfig>& forward,
|
||||||
WeightsT<TConfig>& grad,
|
WeightsT<TConfig>& grad,
|
||||||
ForwardPass<float, TConfig>& backward,
|
ForwardPass<float, TConfig>& backward,
|
||||||
|
RowVectorBatch<float>& inv_timescale,
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
|
|
@ -398,9 +399,10 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
|
||||||
float* next_layer_grad = layer + 1 < kLayers
|
float* next_layer_grad = layer + 1 < kLayers
|
||||||
? backward.layers[layer + 1].input.data()
|
? backward.layers[layer + 1].input.data()
|
||||||
: backward.final_layer_output.data();
|
: backward.final_layer_output.data();
|
||||||
LayerVJP<TConfig, LayerT>(
|
LayerVJP<TConfig, LayerT>(*weights.GetLayer(layer), forward.layers[layer],
|
||||||
*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
|
next_layer_grad, num_tokens,
|
||||||
num_tokens, *grad.GetLayer(layer), backward.layers[layer], pool);
|
*grad.GetLayer(layer), backward.layers[layer],
|
||||||
|
inv_timescale, pool);
|
||||||
}
|
}
|
||||||
|
|
||||||
InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens,
|
InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens,
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
// After highway.h
|
// After highway.h
|
||||||
#include "backprop/backward-inl.h"
|
#include "backprop/backward-inl.h"
|
||||||
|
#include "gemma/activations.h"
|
||||||
#include "gemma/weights.h"
|
#include "gemma/weights.h"
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
HWY_BEFORE_NAMESPACE();
|
||||||
|
|
@ -41,6 +42,7 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
|
||||||
const ByteStorageT& forward_u8,
|
const ByteStorageT& forward_u8,
|
||||||
ByteStorageT& grad_u8,
|
ByteStorageT& grad_u8,
|
||||||
ByteStorageT& backward_u8,
|
ByteStorageT& backward_u8,
|
||||||
|
RowVectorBatch<float>& inv_timescale,
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
using TWeights = CompressedWeights<TConfig>;
|
using TWeights = CompressedWeights<TConfig>;
|
||||||
const auto& weights = *reinterpret_cast<const TWeights*>(weights_u8.get());
|
const auto& weights = *reinterpret_cast<const TWeights*>(weights_u8.get());
|
||||||
|
|
@ -49,25 +51,24 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
|
||||||
const auto& forward = *reinterpret_cast<const TAct*>(forward_u8.get());
|
const auto& forward = *reinterpret_cast<const TAct*>(forward_u8.get());
|
||||||
auto& backward = *reinterpret_cast<TAct*>(backward_u8.get());
|
auto& backward = *reinterpret_cast<TAct*>(backward_u8.get());
|
||||||
CrossEntropyLossBackwardPass<TConfig, CompressedWeights, CompressedLayer>(
|
CrossEntropyLossBackwardPass<TConfig, CompressedWeights, CompressedLayer>(
|
||||||
prompt, weights, forward, grad, backward, pool);
|
prompt, weights, forward, grad, backward, inv_timescale, pool);
|
||||||
}
|
}
|
||||||
|
|
||||||
void CrossEntropyLossBackwardPassT(Model model,
|
void CrossEntropyLossBackwardPassT(Model model, const Prompt& prompt,
|
||||||
const Prompt& prompt,
|
|
||||||
const ByteStorageT& weights,
|
const ByteStorageT& weights,
|
||||||
const ByteStorageT& forward,
|
const ByteStorageT& forward,
|
||||||
ByteStorageT& grad,
|
ByteStorageT& grad, ByteStorageT& backward,
|
||||||
ByteStorageT& backward,
|
RowVectorBatch<float>& inv_timescale,
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
// TODO(janwas): use CallFunctorForModel
|
// TODO(janwas): use CallFunctorForModel
|
||||||
switch (model) {
|
switch (model) {
|
||||||
case Model::GEMMA_2B:
|
case Model::GEMMA_2B:
|
||||||
CrossEntropyLossBackwardPass<ConfigGemma2B<float>>(
|
CrossEntropyLossBackwardPass<ConfigGemma2B<float>>(
|
||||||
prompt, weights, forward, grad, backward, pool);
|
prompt, weights, forward, grad, backward, inv_timescale, pool);
|
||||||
break;
|
break;
|
||||||
case Model::GEMMA_TINY:
|
case Model::GEMMA_TINY:
|
||||||
CrossEntropyLossBackwardPass<ConfigGemmaTiny<float>>(
|
CrossEntropyLossBackwardPass<ConfigGemmaTiny<float>>(
|
||||||
prompt, weights, forward, grad, backward, pool);
|
prompt, weights, forward, grad, backward, inv_timescale, pool);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||||
|
|
@ -83,12 +84,14 @@ namespace gcpp {
|
||||||
|
|
||||||
HWY_EXPORT(CrossEntropyLossBackwardPassT);
|
HWY_EXPORT(CrossEntropyLossBackwardPassT);
|
||||||
|
|
||||||
void CrossEntropyLossBackwardPass(
|
void CrossEntropyLossBackwardPass(const Model& model, const Prompt& prompt,
|
||||||
const Model& model, const Prompt& prompt,
|
const ByteStorageT& weights,
|
||||||
const ByteStorageT& weights, const ByteStorageT& forward,
|
const ByteStorageT& forward,
|
||||||
ByteStorageT& grad, ByteStorageT& backward, hwy::ThreadPool& pool) {
|
ByteStorageT& grad, ByteStorageT& backward,
|
||||||
|
RowVectorBatch<float>& inv_timescale,
|
||||||
|
hwy::ThreadPool& pool) {
|
||||||
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)(
|
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)(
|
||||||
model, prompt, weights, forward, grad, backward, pool);
|
model, prompt, weights, forward, grad, backward, inv_timescale, pool);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
|
|
@ -17,15 +17,18 @@
|
||||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
|
||||||
|
|
||||||
#include "backprop/prompt.h"
|
#include "backprop/prompt.h"
|
||||||
|
#include "gemma/activations.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
void CrossEntropyLossBackwardPass(
|
void CrossEntropyLossBackwardPass(const Model& model, const Prompt& prompt,
|
||||||
const Model& model, const Prompt& prompt,
|
const ByteStorageT& weights,
|
||||||
const ByteStorageT& weights, const ByteStorageT& forward,
|
const ByteStorageT& forward,
|
||||||
ByteStorageT& grad, ByteStorageT& backward, hwy::ThreadPool& pool);
|
ByteStorageT& grad, ByteStorageT& backward,
|
||||||
|
RowVectorBatch<float>& inv_timescale,
|
||||||
|
hwy::ThreadPool& pool);
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <complex>
|
#include <complex>
|
||||||
|
#include <cstdlib> // std::abs
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|
@ -28,9 +29,11 @@
|
||||||
#include "backprop/backward_scalar.h"
|
#include "backprop/backward_scalar.h"
|
||||||
#include "backprop/common_scalar.h"
|
#include "backprop/common_scalar.h"
|
||||||
#include "backprop/forward_scalar.h"
|
#include "backprop/forward_scalar.h"
|
||||||
|
#include "backprop/prompt.h"
|
||||||
#include "backprop/sampler.h"
|
#include "backprop/sampler.h"
|
||||||
#include "backprop/test_util.h"
|
#include "backprop/test_util.h"
|
||||||
#include "compression/weights_raw.h"
|
#include "compression/weights_raw.h"
|
||||||
|
#include "gemma/activations.h"
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
@ -214,6 +217,8 @@ void TestEndToEnd() {
|
||||||
ReverseSequenceSampler training_task({0, 0, 1, 1});
|
ReverseSequenceSampler training_task({0, 0, 1, 1});
|
||||||
std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
|
std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
|
||||||
|
|
||||||
|
RowVectorBatch<float> inv_timescale =
|
||||||
|
Activations::CreateInvTimescale<TestConfig>();
|
||||||
for (const Prompt& prompt : batch) {
|
for (const Prompt& prompt : batch) {
|
||||||
ReverseSequenceSampler::LogPrompt(prompt);
|
ReverseSequenceSampler::LogPrompt(prompt);
|
||||||
RandInit(weights.get(), 1.0f, gen);
|
RandInit(weights.get(), 1.0f, gen);
|
||||||
|
|
@ -223,14 +228,14 @@ void TestEndToEnd() {
|
||||||
|
|
||||||
float loss1 = CrossEntropyLossForwardPass<TestConfig, WeightsF, LayerF>(
|
float loss1 = CrossEntropyLossForwardPass<TestConfig, WeightsF, LayerF>(
|
||||||
prompt.tokens, prompt.context_size, weights.get(), forward1.get(),
|
prompt.tokens, prompt.context_size, weights.get(), forward1.get(),
|
||||||
pool);
|
inv_timescale, pool);
|
||||||
|
|
||||||
EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5);
|
EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5);
|
||||||
|
|
||||||
grad.clear();
|
grad.clear();
|
||||||
CrossEntropyLossBackwardPass<TestConfig, WeightsF, LayerF>(
|
CrossEntropyLossBackwardPass<TestConfig, WeightsF, LayerF>(
|
||||||
prompt, weights.get(), forward1.get(), grad.get(), backward.get(),
|
prompt, weights.get(), forward1.get(), grad.get(), backward.get(),
|
||||||
pool);
|
inv_timescale, pool);
|
||||||
|
|
||||||
Complexify(weights.get(), c_weights.get());
|
Complexify(weights.get(), c_weights.get());
|
||||||
auto func = [&]() {
|
auto func = [&]() {
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "backprop/activations.h"
|
#include "backprop/activations.h"
|
||||||
|
#include "gemma/activations.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
@ -88,11 +89,11 @@ static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs,
|
||||||
return loss * scaling;
|
return loss * scaling;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename TConfig, template<typename> typename LayerT>
|
template <typename TConfig, template <typename> typename LayerT>
|
||||||
void ApplyForwardLayer(const LayerT<TConfig>& weights,
|
void ApplyForwardLayer(const LayerT<TConfig>& weights,
|
||||||
ForwardLayer<float, TConfig>& activations,
|
ForwardLayer<float, TConfig>& activations,
|
||||||
size_t num_tokens,
|
size_t num_tokens, float* HWY_RESTRICT output,
|
||||||
float* HWY_RESTRICT output,
|
const RowVectorBatch<float>& inv_timescale,
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
static constexpr size_t kSeqLen = TConfig::kSeqLen;
|
static constexpr size_t kSeqLen = TConfig::kSeqLen;
|
||||||
|
|
@ -117,14 +118,14 @@ void ApplyForwardLayer(const LayerT<TConfig>& weights,
|
||||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||||
float* HWY_RESTRICT k =
|
float* HWY_RESTRICT k =
|
||||||
activations.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
|
activations.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
|
||||||
Rope(k, kQKVDim, pos);
|
Rope(k, kQKVDim, inv_timescale.Const(), pos);
|
||||||
}
|
}
|
||||||
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
|
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
|
||||||
const size_t head = task % kHeads;
|
const size_t head = task % kHeads;
|
||||||
const size_t pos = task / kHeads;
|
const size_t pos = task / kHeads;
|
||||||
float* HWY_RESTRICT q =
|
float* HWY_RESTRICT q =
|
||||||
activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
|
activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
|
||||||
Rope(q, kQKVDim, pos);
|
Rope(q, kQKVDim, inv_timescale.Const(), pos);
|
||||||
MulByConst(kQueryScale, q, kQKVDim);
|
MulByConst(kQueryScale, q, kQKVDim);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -222,12 +223,13 @@ void ApplyForwardLayer(const LayerT<TConfig>& weights,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename TConfig, template<typename...> typename WeightsT,
|
template <typename TConfig, template <typename...> typename WeightsT,
|
||||||
template<typename> typename LayerT>
|
template <typename> typename LayerT>
|
||||||
float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
|
float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
|
||||||
size_t context_size,
|
size_t context_size,
|
||||||
const WeightsT<TConfig>& weights,
|
const WeightsT<TConfig>& weights,
|
||||||
ForwardPass<float, TConfig>& forward,
|
ForwardPass<float, TConfig>& forward,
|
||||||
|
const RowVectorBatch<float>& inv_timescale,
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
|
|
@ -251,9 +253,9 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
|
||||||
float* HWY_RESTRICT output = layer + 1 < kLayers ?
|
float* HWY_RESTRICT output = layer + 1 < kLayers ?
|
||||||
forward.layers[layer + 1].input.data() :
|
forward.layers[layer + 1].input.data() :
|
||||||
forward.final_layer_output.data();
|
forward.final_layer_output.data();
|
||||||
ApplyForwardLayer<TConfig, LayerT>(
|
ApplyForwardLayer<TConfig, LayerT>(*weights.GetLayer(layer),
|
||||||
*weights.GetLayer(layer), forward.layers[layer],
|
forward.layers[layer], num_tokens,
|
||||||
num_tokens, output, pool);
|
output, inv_timescale, pool);
|
||||||
}
|
}
|
||||||
|
|
||||||
ApplyRMSNorm(weights.final_norm_scale.data(),
|
ApplyRMSNorm(weights.final_norm_scale.data(),
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
#include "backprop/activations.h"
|
#include "backprop/activations.h"
|
||||||
#include "backprop/prompt.h"
|
#include "backprop/prompt.h"
|
||||||
|
#include "gemma/activations.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
|
|
@ -39,28 +40,31 @@ template <typename TConfig>
|
||||||
float CrossEntropyLossForwardPass(const Prompt& prompt,
|
float CrossEntropyLossForwardPass(const Prompt& prompt,
|
||||||
const ByteStorageT& weights_u8,
|
const ByteStorageT& weights_u8,
|
||||||
ByteStorageT& forward_u8,
|
ByteStorageT& forward_u8,
|
||||||
|
RowVectorBatch<float>& inv_timescale,
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
const auto& weights =
|
const auto& weights =
|
||||||
*reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
|
*reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
|
||||||
auto& forward =
|
auto& forward =
|
||||||
*reinterpret_cast<ForwardPass<float, TConfig>*>(forward_u8.get());
|
*reinterpret_cast<ForwardPass<float, TConfig>*>(forward_u8.get());
|
||||||
return
|
return CrossEntropyLossForwardPass<TConfig, CompressedWeights,
|
||||||
CrossEntropyLossForwardPass<TConfig, CompressedWeights, CompressedLayer>(
|
CompressedLayer>(
|
||||||
prompt.tokens, prompt.context_size, weights, forward, pool);
|
prompt.tokens, prompt.context_size, weights, forward, inv_timescale,
|
||||||
|
pool);
|
||||||
}
|
}
|
||||||
|
|
||||||
float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt,
|
float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt,
|
||||||
const ByteStorageT& weights,
|
const ByteStorageT& weights,
|
||||||
ByteStorageT& forward,
|
ByteStorageT& forward,
|
||||||
|
RowVectorBatch<float>& inv_timescale,
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
// TODO(janwas): use CallFunctorForModel
|
// TODO(janwas): use CallFunctorForModel
|
||||||
switch (model) {
|
switch (model) {
|
||||||
case Model::GEMMA_2B:
|
case Model::GEMMA_2B:
|
||||||
return CrossEntropyLossForwardPass<ConfigGemma2B<float>>(prompt, weights,
|
return CrossEntropyLossForwardPass<ConfigGemma2B<float>>(
|
||||||
forward, pool);
|
prompt, weights, forward, inv_timescale, pool);
|
||||||
case Model::GEMMA_TINY:
|
case Model::GEMMA_TINY:
|
||||||
return CrossEntropyLossForwardPass<ConfigGemmaTiny<float>>(
|
return CrossEntropyLossForwardPass<ConfigGemmaTiny<float>>(
|
||||||
prompt, weights, forward, pool);
|
prompt, weights, forward, inv_timescale, pool);
|
||||||
default:
|
default:
|
||||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||||
}
|
}
|
||||||
|
|
@ -75,11 +79,13 @@ namespace gcpp {
|
||||||
|
|
||||||
HWY_EXPORT(CrossEntropyLossForwardPassT);
|
HWY_EXPORT(CrossEntropyLossForwardPassT);
|
||||||
|
|
||||||
float CrossEntropyLossForwardPass(
|
float CrossEntropyLossForwardPass(const Model& model, const Prompt& prompt,
|
||||||
const Model& model, const Prompt& prompt, const ByteStorageT& weights,
|
const ByteStorageT& weights,
|
||||||
ByteStorageT& forward, hwy::ThreadPool& pool) {
|
ByteStorageT& forward,
|
||||||
|
RowVectorBatch<float>& inv_timescale,
|
||||||
|
hwy::ThreadPool& pool) {
|
||||||
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)(
|
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)(
|
||||||
model, prompt, weights, forward, pool);
|
model, prompt, weights, forward, inv_timescale, pool);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
|
|
@ -17,14 +17,17 @@
|
||||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
|
||||||
|
|
||||||
#include "backprop/prompt.h"
|
#include "backprop/prompt.h"
|
||||||
|
#include "gemma/activations.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
float CrossEntropyLossForwardPass(
|
float CrossEntropyLossForwardPass(const Model& model, const Prompt& prompt,
|
||||||
const Model& model, const Prompt& prompt, const ByteStorageT& weights,
|
const ByteStorageT& weights,
|
||||||
ByteStorageT& forward, hwy::ThreadPool& pool);
|
ByteStorageT& forward,
|
||||||
|
RowVectorBatch<float>& inv_timescale,
|
||||||
|
hwy::ThreadPool& pool);
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@
|
||||||
#include "backprop/optimizer.h"
|
#include "backprop/optimizer.h"
|
||||||
#include "backprop/prompt.h"
|
#include "backprop/prompt.h"
|
||||||
#include "backprop/sampler.h"
|
#include "backprop/sampler.h"
|
||||||
|
#include "gemma/activations.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/gemma.h"
|
||||||
#include "gemma/weights.h"
|
#include "gemma/weights.h"
|
||||||
|
|
@ -56,6 +57,9 @@ TEST(OptimizeTest, GradientDescent) {
|
||||||
CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight);
|
CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight);
|
||||||
KVCache kv_cache = KVCache::Create(info.model, /*prefill_tbatch_size=*/16);
|
KVCache kv_cache = KVCache::Create(info.model, /*prefill_tbatch_size=*/16);
|
||||||
|
|
||||||
|
RowVectorBatch<float> inv_timescale =
|
||||||
|
Activations::CreateInvTimescale<ConfigGemmaTiny<float>>();
|
||||||
|
|
||||||
Gemma gemma(GemmaTokenizer(), info, pools);
|
Gemma gemma(GemmaTokenizer(), info, pools);
|
||||||
|
|
||||||
const auto generate = [&](const std::vector<int>& prompt) {
|
const auto generate = [&](const std::vector<int>& prompt) {
|
||||||
|
|
@ -118,10 +122,10 @@ TEST(OptimizeTest, GradientDescent) {
|
||||||
num_ok = 0;
|
num_ok = 0;
|
||||||
for (size_t i = 0; i < kBatchSize; ++i) {
|
for (size_t i = 0; i < kBatchSize; ++i) {
|
||||||
Prompt prompt = training_task.Sample(sgen);
|
Prompt prompt = training_task.Sample(sgen);
|
||||||
total_loss += CrossEntropyLossForwardPass(info.model, prompt,
|
total_loss += CrossEntropyLossForwardPass(
|
||||||
gemma.Weights(), forward, pool);
|
info.model, prompt, gemma.Weights(), forward, inv_timescale, pool);
|
||||||
CrossEntropyLossBackwardPass(info.model, prompt, gemma.Weights(), forward,
|
CrossEntropyLossBackwardPass(info.model, prompt, gemma.Weights(), forward,
|
||||||
grad, backward, pool);
|
grad, backward, inv_timescale, pool);
|
||||||
num_ok += verify(prompt) ? 1 : 0;
|
num_ok += verify(prompt) ? 1 : 0;
|
||||||
}
|
}
|
||||||
total_loss /= kBatchSize;
|
total_loss /= kBatchSize;
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,8 @@
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
#include "gemma/common.h" // kMaxThreads - TODO: remove
|
#include "gemma/common.h" // kMaxThreads - TODO: remove
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h" // HWY_DASSERT
|
#include "hwy/base.h" // HWY_DASSERT
|
||||||
|
|
@ -54,6 +56,7 @@ class RowVectorBatch {
|
||||||
|
|
||||||
// For MatMul or other operations that process the entire batch at once.
|
// For MatMul or other operations that process the entire batch at once.
|
||||||
T* All() { return mem_.get(); }
|
T* All() { return mem_.get(); }
|
||||||
|
const T* Const() const { return mem_.get(); }
|
||||||
size_t NumBytes() const { return batch_size_ * len_ * sizeof(T); }
|
size_t NumBytes() const { return batch_size_ * len_ * sizeof(T); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
@ -88,6 +91,9 @@ struct Activations {
|
||||||
RowVectorBatch<float> griffin_gate_x;
|
RowVectorBatch<float> griffin_gate_x;
|
||||||
RowVectorBatch<float> griffin_multiplier;
|
RowVectorBatch<float> griffin_multiplier;
|
||||||
|
|
||||||
|
// Rope
|
||||||
|
RowVectorBatch<float> inv_timescale;
|
||||||
|
|
||||||
// For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into
|
// For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into
|
||||||
// per-thread storage.
|
// per-thread storage.
|
||||||
// TODO: remove once MatVec is gone.
|
// TODO: remove once MatVec is gone.
|
||||||
|
|
@ -106,6 +112,21 @@ struct Activations {
|
||||||
return TConfig::kQKVDim * (IsMHA<TConfig>() ? 3 : 1);
|
return TConfig::kQKVDim * (IsMHA<TConfig>() ? 3 : 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <class TConfig>
|
||||||
|
static RowVectorBatch<float> CreateInvTimescale() {
|
||||||
|
constexpr size_t kQKVDim = TConfig::kQKVDim;
|
||||||
|
const size_t rope_dim = TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim;
|
||||||
|
RowVectorBatch<float> inv_timescale(1, rope_dim / 2);
|
||||||
|
for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
|
||||||
|
const float freq_exponents =
|
||||||
|
static_cast<float>(2 * dim) / static_cast<float>(rope_dim);
|
||||||
|
// Replacing with expf(ln(1E4) * freq_exponents) changes results
|
||||||
|
// noticeably.
|
||||||
|
inv_timescale.Batch(0)[dim] = 1.0f / std::pow(10000.0f, freq_exponents);
|
||||||
|
}
|
||||||
|
return inv_timescale;
|
||||||
|
}
|
||||||
|
|
||||||
template <class TConfig>
|
template <class TConfig>
|
||||||
void Allocate(size_t batch_size) {
|
void Allocate(size_t batch_size) {
|
||||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
|
|
@ -138,6 +159,8 @@ struct Activations {
|
||||||
griffin_multiplier = RowVectorBatch<float>(batch_size, kModelDim);
|
griffin_multiplier = RowVectorBatch<float>(batch_size, kModelDim);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inv_timescale = CreateInvTimescale<TConfig>();
|
||||||
|
|
||||||
even_odd = RowVectorBatch<float>(1, kModelDim * kMaxThreads);
|
even_odd = RowVectorBatch<float>(1, kModelDim * kMaxThreads);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -196,216 +196,317 @@ HWY_NOINLINE void GriffinRecurrent(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class TConfig, typename T>
|
// Wrapper class; holds arguments in member variables to shorten call sites.
|
||||||
HWY_NOINLINE void PostQK(T* HWY_RESTRICT inout, size_t pos, size_t layer) {
|
|
||||||
constexpr size_t kQKVDim = TConfig::kQKVDim;
|
|
||||||
// PostQKType::Rope
|
|
||||||
Rope(inout, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class TConfig>
|
template <class TConfig>
|
||||||
HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
|
class GemmaAttention {
|
||||||
size_t num_queries, size_t layer,
|
static constexpr size_t kCacheLayerSize = CacheLayerSize<TConfig>()();
|
||||||
Activations& activations,
|
static constexpr size_t kCachePosSize = CachePosSize<TConfig>()();
|
||||||
const CompressedLayer<TConfig>* layer_weights,
|
static constexpr size_t kHeads = TConfig::kHeads;
|
||||||
const KVCaches& kv_caches,
|
static constexpr size_t kKVHeads = TConfig::kKVHeads;
|
||||||
hwy::ThreadPool& pool) {
|
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
PROFILER_ZONE("Gen.Attention");
|
static constexpr size_t kQKVDim = TConfig::kQKVDim;
|
||||||
HWY_DASSERT(interleaved_start % num_queries == 0);
|
static constexpr size_t kQStride = Activations::QStride<TConfig>();
|
||||||
constexpr size_t kQKVDim = TConfig::kQKVDim;
|
static constexpr size_t kSeqLen = TConfig::kSeqLen;
|
||||||
constexpr size_t kQStride = Activations::QStride<TConfig>();
|
static constexpr bool kIsMHA = Activations::IsMHA<TConfig>();
|
||||||
constexpr size_t kCachePosSize = CachePosSize<TConfig>()();
|
|
||||||
constexpr size_t kCacheLayerSize = CacheLayerSize<TConfig>()();
|
|
||||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
|
||||||
constexpr size_t kHeads = TConfig::kHeads;
|
|
||||||
constexpr size_t kKVHeads = TConfig::kKVHeads;
|
|
||||||
constexpr size_t kSeqLen = TConfig::kSeqLen;
|
|
||||||
GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale<TConfig>();
|
|
||||||
|
|
||||||
HWY_ASSERT(num_queries <= kv_caches.size());
|
// The attention window usually starts at 0 unless unless `pos` is larger than
|
||||||
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
|
// the attention window size, then it is `pos` - window_size + 1.
|
||||||
|
static HWY_INLINE size_t StartPos(size_t pos, size_t layer) {
|
||||||
|
const size_t att_window_size = TConfig::kAttentionWindowSizes[layer];
|
||||||
|
return pos - std::min(att_window_size - 1, pos);
|
||||||
|
}
|
||||||
|
|
||||||
// Multi-Head Attention a.k.a. "use_qkv_einsum".
|
template <typename T>
|
||||||
constexpr bool kIsMHA = Activations::IsMHA<TConfig>();
|
HWY_INLINE void PositionalEncodingQK(const T* qk, size_t pos, size_t layer,
|
||||||
static_assert(!kIsMHA || TConfig::kInterleaveQKV); // MHA => interleaved
|
const float mul, T* qk_out) {
|
||||||
const size_t batch_start = interleaved_start / num_queries;
|
const float* inv_timescale = activations_.inv_timescale.Const();
|
||||||
const size_t num_interleaved = num_tokens * num_queries;
|
// PostQKType::Rope
|
||||||
|
(void)layer;
|
||||||
// For the computation of Q, K, and V, it is useful to remember that
|
if (TConfig::kUseHalfRope) {
|
||||||
// qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kKQVDim, kModelDim]
|
hwy::CopyBytes(qk, qk_out, kQKVDim * sizeof(*qk));
|
||||||
// and kQStride = kQKVDim * (kIsMHA ? 3 : 1);
|
Rope(qk_out, kQKVDim / 2, inv_timescale, pos);
|
||||||
//
|
MulByConst(mul, qk_out, kQKVDim);
|
||||||
// Compute Q only or QKV (if MHA).
|
|
||||||
// If MHA, this also computes KV, which we copy to the KV cache below.
|
|
||||||
MatMul_4x4</*kAdd=*/false>(
|
|
||||||
num_interleaved, MakeMat(activations.pre_att_rms_out.All(), kModelDim),
|
|
||||||
MakeMat(layer_weights->qkv_einsum_w.data(), kModelDim),
|
|
||||||
layer_weights->qkv_einsum_w.scale(), /*add=*/nullptr,
|
|
||||||
MakeMat(activations.q.All(), kHeads * kQStride), pool);
|
|
||||||
|
|
||||||
// Compute KV if not MHA.
|
|
||||||
if constexpr (!kIsMHA) {
|
|
||||||
// Single query and no wraparound means we can use a matmul and write
|
|
||||||
// directly into the KV cache with a stride of kCachePosSize.
|
|
||||||
if (num_queries == 1 &&
|
|
||||||
batch_start + num_tokens <= div_seq_len.GetDivisor()) {
|
|
||||||
const size_t kv_ofs =
|
|
||||||
batch_start * kCachePosSize + layer * kCacheLayerSize;
|
|
||||||
// KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
|
|
||||||
float* HWY_RESTRICT kv = kv_caches[0].kv_cache.get() + kv_ofs;
|
|
||||||
MatMul_4x4</*kAdd=*/false>(
|
|
||||||
num_tokens, MakeMat(activations.pre_att_rms_out.All(), kModelDim),
|
|
||||||
MakeMat(layer_weights->qkv_einsum_w.data(), kModelDim, kModelDim,
|
|
||||||
kHeads * kQKVDim * kModelDim),
|
|
||||||
layer_weights->qkv_einsum_w.scale(), /*add=*/nullptr,
|
|
||||||
MakeMat(kv, kKVHeads * 2 * kQKVDim, kCachePosSize), pool);
|
|
||||||
} else {
|
} else {
|
||||||
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
RopeAndMulBy(mul, qk, kQKVDim, inv_timescale, pos, qk_out);
|
||||||
++interleaved_idx) {
|
}
|
||||||
const float* x = activations.pre_att_rms_out.Batch(interleaved_idx);
|
}
|
||||||
const size_t query_idx = interleaved_idx % num_queries;
|
|
||||||
const size_t batch_idx = interleaved_idx / num_queries;
|
// Fills activations.q and computes KV. For kIsMHA, a single MatMul suffices
|
||||||
KVCache& kv_cache = kv_caches[query_idx];
|
// and we later copy KV from q to KVCache. Otherwise, a second MatMul writes
|
||||||
const size_t cache_pos = div_seq_len.Remainder(batch_start + batch_idx);
|
// KV directly to KVCache.
|
||||||
const size_t kv_offset =
|
HWY_NOINLINE void ComputeQKV(const size_t batch_start,
|
||||||
cache_pos * kCachePosSize + layer * kCacheLayerSize;
|
const size_t num_interleaved) {
|
||||||
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
PROFILER_ZONE("Gen.Attention.QKV");
|
||||||
|
// For the computation of Q, K, and V, it is useful to remember that
|
||||||
|
// qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kKQVDim, kModelDim]
|
||||||
|
// and kQStride = kQKVDim * (kIsMHA ? 3 : 1);
|
||||||
|
|
||||||
|
const auto pre_att_rms_out =
|
||||||
|
MakeMat(activations_.pre_att_rms_out.All(), kModelDim);
|
||||||
|
MatMul_4x4</*kAdd=*/false>(
|
||||||
|
num_interleaved, pre_att_rms_out,
|
||||||
|
MakeMat(layer_weights_.qkv_einsum_w.data(), kModelDim),
|
||||||
|
layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr,
|
||||||
|
MakeMat(activations_.q.All(), kHeads * kQStride), pool_);
|
||||||
|
|
||||||
|
if constexpr (kIsMHA) {
|
||||||
|
static_assert(TConfig::kInterleaveQKV, "MHA implies interleaved");
|
||||||
|
// Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
|
||||||
|
} else {
|
||||||
|
// Single query and no wraparound means we can use a matmul and write
|
||||||
|
// directly into the KV cache with a stride of kCachePosSize.
|
||||||
|
if (num_queries_ == 1 &&
|
||||||
|
batch_start + num_tokens_ <= div_seq_len_.GetDivisor()) {
|
||||||
|
const size_t kv_ofs =
|
||||||
|
batch_start * kCachePosSize + layer_ * kCacheLayerSize;
|
||||||
// KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
|
// KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
|
||||||
MatVec<kKVHeads * 2 * kQKVDim, kModelDim>(
|
float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
|
||||||
layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
|
MatMul_4x4</*kAdd=*/false>(
|
||||||
activations.even_odd.All(), kv, pool);
|
num_tokens_, pre_att_rms_out,
|
||||||
|
MakeMat(layer_weights_.qkv_einsum_w.data(), kModelDim, kModelDim,
|
||||||
|
kHeads * kQKVDim * kModelDim),
|
||||||
|
layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr,
|
||||||
|
MakeMat(kv, kKVHeads * 2 * kQKVDim, kCachePosSize), pool_);
|
||||||
|
} else {
|
||||||
|
// Proceed row by row because there will be wraparound.
|
||||||
|
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
||||||
|
++interleaved_idx) {
|
||||||
|
const float* x = activations_.pre_att_rms_out.Batch(interleaved_idx);
|
||||||
|
const size_t query_idx = interleaved_idx % num_queries_;
|
||||||
|
const size_t batch_idx = interleaved_idx / num_queries_;
|
||||||
|
KVCache& kv_cache = kv_caches_[query_idx];
|
||||||
|
const size_t cache_pos =
|
||||||
|
div_seq_len_.Remainder(batch_start + batch_idx);
|
||||||
|
const size_t kv_offset =
|
||||||
|
cache_pos * kCachePosSize + layer_ * kCacheLayerSize;
|
||||||
|
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
||||||
|
// KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
|
||||||
|
MatVec<kKVHeads * 2 * kQKVDim, kModelDim>(
|
||||||
|
layer_weights_.qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
|
||||||
|
activations_.even_odd.All(), kv, pool_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply positional encodings for K (and copy KV to cache if MHA).
|
||||||
|
pool_.Run(
|
||||||
|
0, kKVHeads * num_interleaved,
|
||||||
|
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
|
||||||
|
const size_t head = task % kKVHeads;
|
||||||
|
const size_t interleaved_idx = task / kKVHeads;
|
||||||
|
const size_t query_idx = interleaved_idx % num_queries_;
|
||||||
|
const size_t batch_idx = interleaved_idx / num_queries_;
|
||||||
|
const size_t pos = batch_start + batch_idx;
|
||||||
|
const size_t cache_pos = div_seq_len_.Remainder(pos);
|
||||||
|
const size_t kv_offset = cache_pos * kCachePosSize +
|
||||||
|
layer_ * kCacheLayerSize +
|
||||||
|
head * kQKVDim * 2;
|
||||||
|
KVCache& kv_cache = kv_caches_[query_idx];
|
||||||
|
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
||||||
|
const float* HWY_RESTRICT mha_kv =
|
||||||
|
activations_.q.Batch(interleaved_idx) + head * kQStride + kQKVDim;
|
||||||
|
|
||||||
|
// Copy from `q` if MHA, or apply in-place.
|
||||||
|
PositionalEncodingQK(kIsMHA ? mha_kv : kv, pos, layer_, 1.0f, kv);
|
||||||
|
|
||||||
|
// If MHA, also copy V into KVCache.
|
||||||
|
if (kIsMHA) {
|
||||||
|
hwy::CopyBytes(mha_kv + kQKVDim, kv + kQKVDim,
|
||||||
|
kQKVDim * sizeof(*kv));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Computes Q.K scores, which are "logits" (or scores) stored to head_att.
|
||||||
|
HWY_INLINE void QDotK(const size_t start_pos, const size_t pos,
|
||||||
|
const size_t head_offset, const float* HWY_RESTRICT q,
|
||||||
|
const KVCache& kv_cache, float* HWY_RESTRICT head_att) {
|
||||||
|
if (HWY_LIKELY(pos <= kSeqLen)) {
|
||||||
|
// Slightly faster: no wraparound.
|
||||||
|
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
|
||||||
|
const size_t kv_offset =
|
||||||
|
pos2 * kCachePosSize + layer_ * kCacheLayerSize + head_offset;
|
||||||
|
const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
|
||||||
|
const float score = Dot(q, k, kQKVDim);
|
||||||
|
head_att[pos2] = score;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
|
||||||
|
const size_t cache_pos = div_seq_len_.Remainder(pos2);
|
||||||
|
const size_t kv_offset =
|
||||||
|
cache_pos * kCachePosSize + layer_ * kCacheLayerSize + head_offset;
|
||||||
|
const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
|
||||||
|
const float score = Dot(q, k, kQKVDim);
|
||||||
|
head_att[pos2 % kSeqLen] = score;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply positional encodings for K (and copy KV to cache if MHA).
|
// Accumulates the sum of v (from `kv_cache`) * probability (`head_att`) into
|
||||||
pool.Run(
|
// `att_out`. Equivalent in gemma/modules.py:
|
||||||
0, kKVHeads * num_interleaved,
|
// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
|
||||||
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
|
static HWY_INLINE void WeightedSumV(const size_t start_pos, const size_t pos,
|
||||||
const size_t head = task % kKVHeads;
|
const float* HWY_RESTRICT head_att,
|
||||||
const size_t interleaved_idx = task / kKVHeads;
|
const size_t layer,
|
||||||
const size_t query_idx = interleaved_idx % num_queries;
|
const size_t head_offset,
|
||||||
const size_t batch_idx = interleaved_idx / num_queries;
|
const hwy::Divisor& div_seq_len,
|
||||||
const size_t pos = batch_start + batch_idx;
|
const KVCache& kv_cache,
|
||||||
const size_t cache_pos = div_seq_len.Remainder(pos);
|
float* HWY_RESTRICT att_out) {
|
||||||
const size_t kv_offset = cache_pos * kCachePosSize +
|
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
|
||||||
layer * kCacheLayerSize + head * kQKVDim * 2;
|
|
||||||
KVCache& kv_cache = kv_caches[query_idx];
|
|
||||||
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
|
||||||
if constexpr (kIsMHA) {
|
|
||||||
// For MHA, copy KV into the KV cache from scratch space (see above).
|
|
||||||
const float* HWY_RESTRICT q =
|
|
||||||
activations.q.Batch(interleaved_idx) + head * kQStride;
|
|
||||||
// Skip past the Q part of `q`, and copy KV to `kv`.
|
|
||||||
hwy::CopyBytes(q + kQKVDim, kv, 2 * kQKVDim * sizeof(float));
|
|
||||||
}
|
|
||||||
PostQK<TConfig>(kv, pos, layer);
|
|
||||||
});
|
|
||||||
|
|
||||||
// A "head group" in the context of GQA refers to a collection of query heads
|
if (HWY_LIKELY(pos <= kSeqLen)) {
|
||||||
// that share the same key and value heads.
|
// Slightly faster: no wraparound.
|
||||||
static_assert((kHeads % kKVHeads) == 0,
|
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
|
||||||
"query heads must be a multiple of key-value heads");
|
const size_t kv_offset =
|
||||||
constexpr size_t kHeadGroups = kHeads / kKVHeads;
|
pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
|
||||||
// For each head (token, query), compute Q.K, softmax, and weighted V.
|
const float* HWY_RESTRICT v =
|
||||||
pool.Run(
|
kv_cache.kv_cache.get() + kv_offset + kQKVDim;
|
||||||
0, kHeads * num_interleaved,
|
MulByConstAndAdd(head_att[pos2], v, att_out, kQKVDim);
|
||||||
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
|
}
|
||||||
const size_t head = task % kHeads;
|
} else {
|
||||||
const size_t interleaved_idx = task / kHeads;
|
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
|
||||||
const size_t query_idx = interleaved_idx % num_queries;
|
const size_t cache_pos = div_seq_len.Remainder(pos2);
|
||||||
const size_t batch_idx = interleaved_idx / num_queries;
|
const size_t kv_offset =
|
||||||
const size_t head_offset = (head / kHeadGroups) * kQKVDim * 2;
|
cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
|
||||||
KVCache& kv_cache = kv_caches[query_idx];
|
const float* HWY_RESTRICT v =
|
||||||
float* HWY_RESTRICT q =
|
kv_cache.kv_cache.get() + kv_offset + kQKVDim;
|
||||||
activations.q.Batch(interleaved_idx) + head * kQStride;
|
MulByConstAndAdd(head_att[pos2 % kSeqLen], v, att_out, kQKVDim);
|
||||||
|
}
|
||||||
// Apply rope and scaling to Q.
|
|
||||||
const size_t pos = batch_start + batch_idx;
|
|
||||||
PostQK<TConfig>(q, pos, layer);
|
|
||||||
MulByConst(kQueryScale, q, kQKVDim);
|
|
||||||
|
|
||||||
// Compute Q.K scores, yielding "logits" (or scores) in head_att.
|
|
||||||
float* HWY_RESTRICT head_att =
|
|
||||||
activations.att.Batch(interleaved_idx) + head * kSeqLen;
|
|
||||||
// Usually start_pos is 0, unless pos is larger than the attention
|
|
||||||
// window size, then it is pos - window_size + 1.
|
|
||||||
const size_t start_pos =
|
|
||||||
pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
|
|
||||||
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
|
|
||||||
const size_t cache_pos = div_seq_len.Remainder(pos2);
|
|
||||||
const size_t kv_offset =
|
|
||||||
cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
|
|
||||||
const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
|
|
||||||
const float score = Dot(q, k, kQKVDim);
|
|
||||||
head_att[pos2 % kSeqLen] = score;
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoftMax. May be preceded by SoftCap. Yields "probabilities" in
|
|
||||||
// head_att.
|
|
||||||
const size_t head_att_len = std::min(pos + 1, kSeqLen);
|
|
||||||
if constexpr (TConfig::kAttCap > 0.0f) {
|
|
||||||
LogitsSoftCap(TConfig::kAttCap, head_att, head_att_len);
|
|
||||||
}
|
|
||||||
Softmax(head_att, head_att_len);
|
|
||||||
|
|
||||||
// Summation of v (kv_cache) weighted by probs (head_att)
|
|
||||||
// into "encoded" (att_out). Compare gemma/modules.py:
|
|
||||||
// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
|
|
||||||
float* HWY_RESTRICT att_out =
|
|
||||||
activations.att_out.Batch(interleaved_idx) + head * kQKVDim;
|
|
||||||
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
|
|
||||||
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
|
|
||||||
const size_t cache_pos = div_seq_len.Remainder(pos2);
|
|
||||||
const size_t kv_offset =
|
|
||||||
cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
|
|
||||||
float* HWY_RESTRICT v =
|
|
||||||
kv_cache.kv_cache.get() + kv_offset + kQKVDim;
|
|
||||||
MulByConstAndAdd(head_att[pos2 % kSeqLen], v, att_out, kQKVDim);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Sum encoded (att_out) over num_heads and head_dim (kQKVDim)
|
|
||||||
// into output (layer_out). Compare gemma/modules.py:
|
|
||||||
// attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
|
|
||||||
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
|
||||||
++interleaved_idx) {
|
|
||||||
// TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
|
|
||||||
// rearranging the weights.
|
|
||||||
float* HWY_RESTRICT att_out = activations.att_out.Batch(interleaved_idx);
|
|
||||||
float* HWY_RESTRICT layer_out =
|
|
||||||
activations.att_post2.Batch(interleaved_idx);
|
|
||||||
// Head 0 (and potentially biases) -> layer_out.
|
|
||||||
// attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
|
|
||||||
constexpr bool kAdd = TConfig::kSoftmaxAttnOutputBiases;
|
|
||||||
const float* bias =
|
|
||||||
kAdd ? layer_weights->attention_output_biases.data_scale1() : nullptr;
|
|
||||||
MatVecT<kAdd, kModelDim, kQKVDim>(
|
|
||||||
layer_weights->attn_vec_einsum_w, 0, att_out, bias,
|
|
||||||
activations.even_odd.All(), layer_out, pool);
|
|
||||||
// Head 1 and following are added to layer_out.
|
|
||||||
for (size_t head = 1; head < kHeads; ++head) {
|
|
||||||
// NOTE: this is a single kModelDim temp output. If parallelized or using
|
|
||||||
// MatMul, add per-thread storage.
|
|
||||||
float* HWY_RESTRICT head_out = activations.att_post1.All();
|
|
||||||
// TODO: requires MatMul support for offsets.
|
|
||||||
MatVec<kModelDim, kQKVDim>(
|
|
||||||
layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim,
|
|
||||||
att_out + head * kQKVDim, activations.even_odd.All(), head_out, pool);
|
|
||||||
AddFrom(head_out, layer_out, kModelDim);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
HWY_NOINLINE void DotSoftmaxWeightedSum(const size_t batch_start,
|
||||||
|
const size_t num_interleaved) {
|
||||||
|
PROFILER_ZONE("Gen.Attention.DotSoftmax");
|
||||||
|
GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale<TConfig>();
|
||||||
|
|
||||||
|
// A "head group" in the context of GQA refers to a collection of query
|
||||||
|
// heads that share the same key and value heads.
|
||||||
|
static_assert((kHeads % kKVHeads) == 0,
|
||||||
|
"query heads must be a multiple of key-value heads");
|
||||||
|
constexpr size_t kHeadGroups = kHeads / kKVHeads;
|
||||||
|
|
||||||
|
// For each head (token, query), compute Q.K, softmax, and weighted V.
|
||||||
|
pool_.Run(0, kHeads * num_interleaved,
|
||||||
|
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
|
||||||
|
const size_t head = task % kHeads;
|
||||||
|
const size_t interleaved_idx = task / kHeads;
|
||||||
|
const size_t query_idx = interleaved_idx % num_queries_;
|
||||||
|
const size_t batch_idx = interleaved_idx / num_queries_;
|
||||||
|
const size_t head_offset = (head / kHeadGroups) * kQKVDim * 2;
|
||||||
|
KVCache& kv_cache = kv_caches_[query_idx];
|
||||||
|
float* HWY_RESTRICT q =
|
||||||
|
activations_.q.Batch(interleaved_idx) + head * kQStride;
|
||||||
|
|
||||||
|
// Apply rope and scaling to Q.
|
||||||
|
const size_t pos = batch_start + batch_idx;
|
||||||
|
PositionalEncodingQK(q, pos, layer_, kQueryScale, q);
|
||||||
|
|
||||||
|
const size_t start_pos = StartPos(pos, layer_);
|
||||||
|
|
||||||
|
float* HWY_RESTRICT head_att =
|
||||||
|
activations_.att.Batch(interleaved_idx) + head * kSeqLen;
|
||||||
|
QDotK(start_pos, pos, head_offset, q, kv_cache, head_att);
|
||||||
|
// SoftMax with optional SoftCap yields "probabilities" in
|
||||||
|
// head_att.
|
||||||
|
const size_t head_att_len = std::min(pos + 1, kSeqLen);
|
||||||
|
MaybeLogitsSoftCap(TConfig::kAttCap, head_att, head_att_len);
|
||||||
|
Softmax(head_att, head_att_len);
|
||||||
|
|
||||||
|
float* HWY_RESTRICT att_out =
|
||||||
|
activations_.att_out.Batch(interleaved_idx) +
|
||||||
|
head * kQKVDim;
|
||||||
|
WeightedSumV(start_pos, pos, head_att, layer_, head_offset,
|
||||||
|
div_seq_len_, kv_cache, att_out);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sums encoded (`att_out`) over num_heads and head_dim (kQKVDim) into output
|
||||||
|
// (`layer_out`). Compare gemma/modules.py:
|
||||||
|
// attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
|
||||||
|
HWY_NOINLINE void SumHeads(const size_t num_interleaved) {
|
||||||
|
PROFILER_ZONE("Gen.Attention.SumHeads");
|
||||||
|
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
||||||
|
++interleaved_idx) {
|
||||||
|
// TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
|
||||||
|
// rearranging the weights.
|
||||||
|
float* HWY_RESTRICT att_out = activations_.att_out.Batch(interleaved_idx);
|
||||||
|
float* HWY_RESTRICT layer_out =
|
||||||
|
activations_.att_post2.Batch(interleaved_idx);
|
||||||
|
// Head 0 (and potentially biases) -> layer_out.
|
||||||
|
// attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
|
||||||
|
constexpr bool kAdd = TConfig::kSoftmaxAttnOutputBiases;
|
||||||
|
const float* bias =
|
||||||
|
kAdd ? layer_weights_.attention_output_biases.data_scale1() : nullptr;
|
||||||
|
MatVecT<kAdd, kModelDim, kQKVDim>(
|
||||||
|
layer_weights_.attn_vec_einsum_w, 0, att_out, bias,
|
||||||
|
activations_.even_odd.All(), layer_out, pool_);
|
||||||
|
// Head 1 and following are added to layer_out.
|
||||||
|
for (size_t head = 1; head < kHeads; ++head) {
|
||||||
|
// NOTE: this is a single kModelDim temp output. If parallelized or
|
||||||
|
// using MatMul, add per-thread storage.
|
||||||
|
float* HWY_RESTRICT head_out = activations_.att_post1.All();
|
||||||
|
// TODO: requires MatMul support for offsets.
|
||||||
|
MatVec<kModelDim, kQKVDim>(
|
||||||
|
layer_weights_.attn_vec_einsum_w, head * kModelDim * kQKVDim,
|
||||||
|
att_out + head * kQKVDim, activations_.even_odd.All(), head_out,
|
||||||
|
pool_);
|
||||||
|
AddFrom(head_out, layer_out, kModelDim);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
GemmaAttention(size_t interleaved_start, size_t num_tokens,
|
||||||
|
size_t num_queries, size_t layer, Activations& activations,
|
||||||
|
const CompressedLayer<TConfig>* layer_weights,
|
||||||
|
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
|
||||||
|
hwy::ThreadPool& pool)
|
||||||
|
: interleaved_start_(interleaved_start),
|
||||||
|
num_tokens_(num_tokens),
|
||||||
|
num_queries_(num_queries),
|
||||||
|
layer_(layer),
|
||||||
|
activations_(activations),
|
||||||
|
layer_weights_(*layer_weights),
|
||||||
|
div_seq_len_(div_seq_len),
|
||||||
|
kv_caches_(kv_caches),
|
||||||
|
pool_(pool) {
|
||||||
|
HWY_DASSERT(interleaved_start_ % num_queries_ == 0);
|
||||||
|
HWY_DASSERT(num_queries_ <= kv_caches_.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
HWY_INLINE void operator()() {
|
||||||
|
const size_t batch_start = interleaved_start_ / num_queries_;
|
||||||
|
const size_t num_interleaved = num_tokens_ * num_queries_;
|
||||||
|
|
||||||
|
ComputeQKV(batch_start, num_interleaved);
|
||||||
|
DotSoftmaxWeightedSum(batch_start, num_interleaved);
|
||||||
|
SumHeads(num_interleaved);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const size_t interleaved_start_;
|
||||||
|
const size_t num_tokens_;
|
||||||
|
const size_t num_queries_;
|
||||||
|
const size_t layer_;
|
||||||
|
Activations& activations_;
|
||||||
|
const CompressedLayer<TConfig>& layer_weights_;
|
||||||
|
const hwy::Divisor& div_seq_len_;
|
||||||
|
const KVCaches& kv_caches_;
|
||||||
|
hwy::ThreadPool& pool_;
|
||||||
|
};
|
||||||
|
|
||||||
template <class TConfig>
|
template <class TConfig>
|
||||||
HWY_NOINLINE void Attention(LayerAttentionType type, size_t interleaved_start,
|
HWY_NOINLINE void Attention(LayerAttentionType type, size_t interleaved_start,
|
||||||
size_t num_tokens, size_t num_queries, size_t layer,
|
size_t num_tokens, size_t num_queries, size_t layer,
|
||||||
Activations& activations,
|
Activations& activations,
|
||||||
const CompressedLayer<TConfig>* layer_weights,
|
const CompressedLayer<TConfig>* layer_weights,
|
||||||
|
const hwy::Divisor& div_seq_len,
|
||||||
const KVCaches& kv_caches, hwy::ThreadPool& pool) {
|
const KVCaches& kv_caches, hwy::ThreadPool& pool) {
|
||||||
if (type == LayerAttentionType::kGemma) {
|
if (type == LayerAttentionType::kGemma) {
|
||||||
GemmaAttention<TConfig>(interleaved_start, num_tokens, num_queries, layer,
|
GemmaAttention<TConfig>(interleaved_start, num_tokens, num_queries, layer,
|
||||||
activations, layer_weights, kv_caches, pool);
|
activations, layer_weights, div_seq_len, kv_caches,
|
||||||
|
pool)();
|
||||||
} else {
|
} else {
|
||||||
// Only reached if the model is Griffin. `if constexpr` prevents generating
|
// Only reached if the model is Griffin. `if constexpr` prevents generating
|
||||||
// this code for non-Griffin models.
|
// this code for non-Griffin models.
|
||||||
|
|
@ -421,6 +522,7 @@ HWY_NOINLINE void Attention(LayerAttentionType type, size_t interleaved_start,
|
||||||
template <class TConfig, typename T>
|
template <class TConfig, typename T>
|
||||||
HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2,
|
HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2,
|
||||||
size_t count) {
|
size_t count) {
|
||||||
|
PROFILER_ZONE("Gen.Activation");
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using DF = hn::ScalableTag<T>;
|
using DF = hn::ScalableTag<T>;
|
||||||
using VF = hn::Vec<DF>;
|
using VF = hn::Vec<DF>;
|
||||||
|
|
@ -516,7 +618,8 @@ template <class TConfig>
|
||||||
HWY_NOINLINE void TransformerLayer(
|
HWY_NOINLINE void TransformerLayer(
|
||||||
size_t num_tokens, size_t num_queries, size_t pos, size_t layer,
|
size_t num_tokens, size_t num_queries, size_t pos, size_t layer,
|
||||||
const CompressedLayer<TConfig>* layer_weights, Activations& activations,
|
const CompressedLayer<TConfig>* layer_weights, Activations& activations,
|
||||||
const KVCaches& kv_caches, hwy::ThreadPool& pool) {
|
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
|
||||||
|
hwy::ThreadPool& pool) {
|
||||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
const size_t num_interleaved = num_tokens * num_queries;
|
const size_t num_interleaved = num_tokens * num_queries;
|
||||||
auto type = TConfig::kLayerConfig[layer];
|
auto type = TConfig::kLayerConfig[layer];
|
||||||
|
|
@ -528,7 +631,7 @@ HWY_NOINLINE void TransformerLayer(
|
||||||
activations.pre_att_rms_out.All(), kModelDim);
|
activations.pre_att_rms_out.All(), kModelDim);
|
||||||
|
|
||||||
Attention<TConfig>(type, pos, num_tokens, num_queries, layer_of_type,
|
Attention<TConfig>(type, pos, num_tokens, num_queries, layer_of_type,
|
||||||
activations, layer_weights, kv_caches, pool);
|
activations, layer_weights, div_seq_len, kv_caches, pool);
|
||||||
|
|
||||||
PostNorm<TConfig>(num_interleaved, layer_weights->post_attention_norm_scale,
|
PostNorm<TConfig>(num_interleaved, layer_weights->post_attention_norm_scale,
|
||||||
activations.att_post2.All());
|
activations.att_post2.All());
|
||||||
|
|
@ -606,6 +709,7 @@ class PrefillState {
|
||||||
const size_t query_idx_start,
|
const size_t query_idx_start,
|
||||||
const CompressedWeights<TConfig>& weights,
|
const CompressedWeights<TConfig>& weights,
|
||||||
const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config,
|
||||||
|
const hwy::Divisor& div_seq_len,
|
||||||
const KVCaches& kv_caches, PerClusterPools& pools) {
|
const KVCaches& kv_caches, PerClusterPools& pools) {
|
||||||
PROFILER_ZONE("Gen.Prefill");
|
PROFILER_ZONE("Gen.Prefill");
|
||||||
const size_t num_queries = prompts.size();
|
const size_t num_queries = prompts.size();
|
||||||
|
|
@ -638,9 +742,10 @@ class PrefillState {
|
||||||
// Transformer with one batch of tokens from a single query.
|
// Transformer with one batch of tokens from a single query.
|
||||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||||
const auto* layer_weights = weights.GetLayer(layer);
|
const auto* layer_weights = weights.GetLayer(layer);
|
||||||
TransformerLayer<TConfig>(
|
TransformerLayer<TConfig>(tbatch_size, kPrefillQueries,
|
||||||
tbatch_size, kPrefillQueries, pos + tbatch_start, layer,
|
pos + tbatch_start, layer,
|
||||||
layer_weights, activations, prefill_kv_caches, inner_pool);
|
layer_weights, activations, div_seq_len,
|
||||||
|
prefill_kv_caches, inner_pool);
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: we unconditionally call StreamToken, even if EOS.
|
// NOTE: we unconditionally call StreamToken, even if EOS.
|
||||||
|
|
@ -664,6 +769,7 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
|
||||||
size_t num_queries, size_t pos,
|
size_t num_queries, size_t pos,
|
||||||
const CompressedWeights<TConfig>& weights,
|
const CompressedWeights<TConfig>& weights,
|
||||||
Activations& activations,
|
Activations& activations,
|
||||||
|
const hwy::Divisor& div_seq_len,
|
||||||
const KVCaches& kv_caches, hwy::ThreadPool& pool,
|
const KVCaches& kv_caches, hwy::ThreadPool& pool,
|
||||||
const LayersOutputFunc& layers_output) {
|
const LayersOutputFunc& layers_output) {
|
||||||
const size_t num_interleaved = num_tokens * num_queries;
|
const size_t num_interleaved = num_tokens * num_queries;
|
||||||
|
|
@ -684,7 +790,8 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
|
||||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||||
const CompressedLayer<TConfig>* layer_weights = weights.GetLayer(layer);
|
const CompressedLayer<TConfig>* layer_weights = weights.GetLayer(layer);
|
||||||
TransformerLayer<TConfig>(num_tokens, num_queries, pos, layer,
|
TransformerLayer<TConfig>(num_tokens, num_queries, pos, layer,
|
||||||
layer_weights, activations, kv_caches, pool);
|
layer_weights, activations, div_seq_len,
|
||||||
|
kv_caches, pool);
|
||||||
|
|
||||||
if (layers_output) {
|
if (layers_output) {
|
||||||
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
|
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
|
||||||
|
|
@ -822,6 +929,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
||||||
HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096.
|
HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096.
|
||||||
HWY_ASSERT(num_queries <= activations.x.BatchSize());
|
HWY_ASSERT(num_queries <= activations.x.BatchSize());
|
||||||
HWY_ASSERT(kv_caches.size() == num_queries);
|
HWY_ASSERT(kv_caches.size() == num_queries);
|
||||||
|
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
|
||||||
|
|
||||||
size_t min_prompt_size, max_prompt_size;
|
size_t min_prompt_size, max_prompt_size;
|
||||||
const std::vector<int> prompt = InterleaveQueries(
|
const std::vector<int> prompt = InterleaveQueries(
|
||||||
|
|
@ -857,7 +965,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
||||||
pools);
|
pools);
|
||||||
prefill_start = hwy::platform::Now();
|
prefill_start = hwy::platform::Now();
|
||||||
prefill.Prefill<TConfig>(prompts, prefill_per_query, pos, query_idx_start,
|
prefill.Prefill<TConfig>(prompts, prefill_per_query, pos, query_idx_start,
|
||||||
weights, runtime_config, kv_caches, pools);
|
weights, runtime_config, div_seq_len, kv_caches,
|
||||||
|
pools);
|
||||||
timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start);
|
timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -881,8 +990,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
||||||
++gen_per_query) {
|
++gen_per_query) {
|
||||||
// Decode: generate one token for each query.
|
// Decode: generate one token for each query.
|
||||||
Transformer<TConfig>(gen_tokens.data(), /*num_tokens=*/1, num_queries,
|
Transformer<TConfig>(gen_tokens.data(), /*num_tokens=*/1, num_queries,
|
||||||
interleaved_pos, weights, activations, kv_caches, pool,
|
interleaved_pos, weights, activations, div_seq_len,
|
||||||
runtime_config.layers_output);
|
kv_caches, pool, runtime_config.layers_output);
|
||||||
interleaved_pos += num_queries;
|
interleaved_pos += num_queries;
|
||||||
|
|
||||||
bool all_queries_eos = true;
|
bool all_queries_eos = true;
|
||||||
|
|
@ -895,9 +1004,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
||||||
MakeMat(activations.logits.All(), kVocabSize), pool);
|
MakeMat(activations.logits.All(), kVocabSize), pool);
|
||||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||||
float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
|
float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
|
||||||
if constexpr (TConfig::kFinalCap > 0.0f) {
|
MaybeLogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize);
|
||||||
LogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize);
|
|
||||||
}
|
|
||||||
Softmax(logits, kVocabSize);
|
Softmax(logits, kVocabSize);
|
||||||
const int token = sample_token(logits, kVocabSize);
|
const int token = sample_token(logits, kVocabSize);
|
||||||
timing_info.NotifyGenerated(prefill_start, gen_start);
|
timing_info.NotifyGenerated(prefill_start, gen_start);
|
||||||
|
|
|
||||||
|
|
@ -381,16 +381,16 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
|
||||||
of this rotation matrix which is simply the same matrix with -pos parameter)
|
of this rotation matrix which is simply the same matrix with -pos parameter)
|
||||||
*/
|
*/
|
||||||
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
|
// `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
|
||||||
size_t dim_qkv, int pos) {
|
// This overload is called from backprop/ and if kUseHalfRope.
|
||||||
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
|
||||||
|
float* HWY_RESTRICT x, size_t dim_qkv,
|
||||||
|
const float* HWY_RESTRICT inv_timescale, int pos) {
|
||||||
|
PROFILER_FUNC;
|
||||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||||
const size_t half_dim_qkv = dim_qkv / 2;
|
const size_t half_dim_qkv = dim_qkv / 2;
|
||||||
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
|
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
|
||||||
const float freq_exponents =
|
const float theta = StaticCast<float>(pos) * inv_timescale[dim];
|
||||||
StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
|
|
||||||
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
|
|
||||||
const float timescale = powf(10000.0f, freq_exponents);
|
|
||||||
const float theta = StaticCast<float>(pos) / timescale;
|
|
||||||
const float cos_val = cosf(theta);
|
const float cos_val = cosf(theta);
|
||||||
const float sin_val = sinf(theta);
|
const float sin_val = sinf(theta);
|
||||||
const float x0 = x[dim];
|
const float x0 = x[dim];
|
||||||
|
|
@ -400,24 +400,23 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul,
|
// TODO(janwas): vectorize
|
||||||
float* HWY_RESTRICT x,
|
// `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
|
||||||
size_t dim_qkv,
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
|
||||||
int pos) {
|
const float mul, const float* HWY_RESTRICT x, size_t dim_qkv,
|
||||||
|
const float* HWY_RESTRICT inv_timescale, int pos,
|
||||||
|
float* HWY_RESTRICT x_out) {
|
||||||
|
PROFILER_FUNC;
|
||||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||||
const size_t half_dim_qkv = dim_qkv / 2;
|
const size_t half_dim_qkv = dim_qkv / 2;
|
||||||
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
|
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
|
||||||
const float freq_exponents =
|
const float theta = StaticCast<float>(pos) * inv_timescale[dim];
|
||||||
StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
|
|
||||||
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
|
|
||||||
const float timescale = powf(10000.0f, freq_exponents);
|
|
||||||
const float theta = StaticCast<float>(pos) / timescale;
|
|
||||||
const float cos_val = cosf(theta);
|
const float cos_val = cosf(theta);
|
||||||
const float sin_val = sinf(theta);
|
const float sin_val = sinf(theta);
|
||||||
const float x0 = x[dim];
|
const float x0 = x[dim];
|
||||||
const float x1 = x[dim + half_dim_qkv];
|
const float x1 = x[dim + half_dim_qkv];
|
||||||
x[dim] = mul * (x0 * cos_val - x1 * sin_val);
|
x_out[dim] = mul * (x0 * cos_val - x1 * sin_val);
|
||||||
x[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val);
|
x_out[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -577,12 +576,19 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
static HWY_INLINE HWY_MAYBE_UNUSED void LogitsSoftCap(const float cap,
|
static HWY_INLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
|
||||||
float* HWY_RESTRICT x,
|
const size_t size) {
|
||||||
const size_t size) {
|
|
||||||
LogitsSoftCap(cap, x, size, size);
|
LogitsSoftCap(cap, x, size, size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Calls LogitsSoftCap if cap != 0.0f.
|
||||||
|
static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
|
||||||
|
const float cap, float* HWY_RESTRICT x, const size_t size) {
|
||||||
|
if (cap != 0.0f) {
|
||||||
|
LogitsSoftCap(cap, x, size, size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED size_t
|
static HWY_NOINLINE HWY_MAYBE_UNUSED size_t
|
||||||
SampleArgmax(const float* probabilities, size_t vocab_size) {
|
SampleArgmax(const float* probabilities, size_t vocab_size) {
|
||||||
size_t max_index = 0;
|
size_t max_index = 0;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue