Fix bf16 KV recompression and Rope(), fixes #608

Also add more helpful error message for prompt > seq_len

Also update ops_test, adding coverage for Rope().

PiperOrigin-RevId: 772945644
This commit is contained in:
Jan Wassenberg 2025-06-18 09:13:47 -07:00 committed by Copybara-Service
parent 88284387db
commit 7f62c2606e
4 changed files with 86 additions and 108 deletions

View File

@ -299,19 +299,21 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
layer_idx * cache_layer_size +
head * qkv_dim * 2;
HWY_ALIGN float kv_f32[2 * kMaxQKVDim];
const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, MakeSpan(kv, 2 * qkv_dim), 0, kv_f32,
2 * qkv_dim);
// Apply further processing to K.
if (layer.key_norm_scale.HasPtr()) {
CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), 0, kv, qkv_dim);
RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim);
});
}
HWY_ALIGN float kv_f32[kMaxQKVDim];
const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, MakeSpan(kv, qkv_dim), 0, kv_f32, qkv_dim);
PositionalEncodingQK(kv_f32, layer_idx, layer, activations, pos);
CompressPerThread tls;
Compress(kv_f32, qkv_dim, tls, MakeSpan(kv, qkv_dim), 0);
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
});
}

View File

@ -480,7 +480,10 @@ static void GenerateT(const ModelConfig& config,
// We use a single divisor, so all sequence lengths must be the same.
HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
}
HWY_ASSERT(max_prompt_size < seq_len);
if (max_prompt_size >= seq_len) {
HWY_ABORT("max_prompt_size = %zu, increase --seq_len to at least that.",
max_prompt_size);
}
HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);
// Lacks a constructor to bulk-set, hence initialized by Prefill* which have

View File

@ -406,7 +406,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
// This overload is called if `post_qk == PostQKType::HalfRope`.
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
float* HWY_RESTRICT x, size_t dim_qkv,
float* HWY_RESTRICT x, const size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, const int pos) {
PROFILER_ZONE("ops.Rope");
HWY_DASSERT(dim_qkv % 2 == 0);
@ -430,13 +430,13 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
// Scale input with rotations.
VF vx0 = hn::LoadU(df, x + dim);
VF vx1 = hn::LoadU(df, x + dim + half_dim_qkv);
vx0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
vx1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
const VF vx0 = hn::LoadU(df, x + dim);
const VF vx1 = hn::LoadU(df, x + dim + half_dim_qkv);
const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
hn::StoreU(vx0, df, x + dim);
hn::StoreU(vx1, df, x + dim + half_dim_qkv);
hn::StoreU(vout0, df, x + dim);
hn::StoreU(vout1, df, x + dim + half_dim_qkv);
}
// Vectorize computation for remaining dims - same as above, but with LoadN.
@ -452,19 +452,19 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
// Scale input with rotations.
VF vx0 = hn::LoadN(df, x + dim, remaining_dims);
VF vx1 = hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims);
vx0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
vx1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
const VF vx0 = hn::LoadN(df, x + dim, remaining_dims);
const VF vx1 = hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims);
const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
hn::StoreN(vx0, df, x + dim, remaining_dims);
hn::StoreN(vx1, df, x + dim + half_dim_qkv, remaining_dims);
hn::StoreN(vout0, df, x + dim, remaining_dims);
hn::StoreN(vout1, df, x + dim + half_dim_qkv, remaining_dims);
}
}
// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, const int pos) {
PROFILER_ZONE("ops.RopeAndMulBy");
HWY_DASSERT(dim_qkv % 2 == 0);
@ -489,13 +489,13 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
// Scale input with rotations and multiply with constant.
VF vx0 = hn::Mul(vmul, hn::LoadU(df, x + dim));
VF vx1 = hn::Mul(vmul, hn::LoadU(df, x + dim + half_dim_qkv));
vx0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
vx1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
const VF vx0 = hn::Mul(vmul, hn::LoadU(df, x + dim));
const VF vx1 = hn::Mul(vmul, hn::LoadU(df, x + dim + half_dim_qkv));
const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
hn::StoreU(vx0, df, x + dim);
hn::StoreU(vx1, df, x + dim + half_dim_qkv);
hn::StoreU(vout0, df, x + dim);
hn::StoreU(vout1, df, x + dim + half_dim_qkv);
}
// Vectorize computation for remaining dims - same as above, but with LoadN.
@ -511,14 +511,14 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
// Scale input with rotations and multiply with constant.
VF vx0 = hn::Mul(vmul, hn::LoadN(df, x + dim, remaining_dims));
VF vx1 =
const VF vx0 = hn::Mul(vmul, hn::LoadN(df, x + dim, remaining_dims));
const VF vx1 =
hn::Mul(vmul, hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims));
vx0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
vx1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
hn::StoreN(vx0, df, x + dim, remaining_dims);
hn::StoreN(vx1, df, x + dim + half_dim_qkv, remaining_dims);
hn::StoreN(vout0, df, x + dim, remaining_dims);
hn::StoreN(vout1, df, x + dim + half_dim_qkv, remaining_dims);
}
}

View File

@ -83,48 +83,44 @@ T Random(hwy::RandomState& rng) {
HWY_MAX(hwy::ConvertScalarTo<double>(hwy::LowestValue<T>()), val));
}
HWY_NOINLINE void SourceAddFrom(const float* HWY_RESTRICT other,
HWY_NOINLINE void SimpleAddFrom(const float* HWY_RESTRICT other,
float* HWY_RESTRICT x, size_t size) {
for (size_t i = 0; i < size; ++i) {
x[i] += other[i];
}
}
HWY_NOINLINE void SourceMulBy(const float* HWY_RESTRICT other,
float* HWY_RESTRICT x, size_t size,
size_t max_pos) {
HWY_DASSERT(max_pos <= size);
for (size_t i = 0; i < max_pos; ++i) {
HWY_NOINLINE void SimpleMulBy(const float* HWY_RESTRICT other,
float* HWY_RESTRICT x, size_t size) {
for (size_t i = 0; i < size; ++i) {
x[i] *= other[i];
}
}
HWY_NOINLINE void SourceMulByConst(float c, float* HWY_RESTRICT x, size_t size,
size_t max_pos) {
for (size_t i = 0; i < max_pos; ++i) {
HWY_NOINLINE void SimpleMulByConst(float c, float* HWY_RESTRICT x,
size_t size) {
for (size_t i = 0; i < size; ++i) {
x[i] *= c;
}
}
HWY_NOINLINE void SourceMulByConstAndAdd(float c, const float* HWY_RESTRICT x,
HWY_NOINLINE void SimpleMulByConstAndAdd(float c, const float* HWY_RESTRICT x,
float* HWY_RESTRICT out, size_t size) {
for (size_t i = 0; i < size; ++i) {
out[i] += x[i] * c;
}
}
HWY_NOINLINE void SourceSoftmax(float* HWY_RESTRICT x, size_t size,
size_t mask_pos) {
HWY_NOINLINE void SimpleSoftmax(float* HWY_RESTRICT x, size_t size) {
HWY_DASSERT(size != 0);
HWY_DASSERT(mask_pos <= size);
float sum = 0.0;
const float maxval = *std::max_element(x, x + mask_pos);
for (size_t i = 0; i < mask_pos; ++i) {
const float maxval = *std::max_element(x, x + size);
for (size_t i = 0; i < size; ++i) {
x[i] = std::exp(x[i] - maxval);
sum += x[i];
}
const float scale = 1.0f / sum;
for (size_t i = 0; i < mask_pos; ++i) {
for (size_t i = 0; i < size; ++i) {
x[i] *= scale;
}
}
@ -169,7 +165,7 @@ struct TestAddFrom {
o[i] = Random<T>(rng);
}
SourceAddFrom(o, e, count);
SimpleAddFrom(o, e, count);
AddFrom(o, x, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
@ -177,38 +173,6 @@ struct TestAddFrom {
}
};
struct TestMulBy {
template <class D>
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
hwy::RandomState& rng) {
using T = hn::TFromD<D>;
hwy::AlignedFreeUniquePtr<T[]> px =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
hwy::AlignedFreeUniquePtr<T[]> pe =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
hwy::AlignedFreeUniquePtr<T[]> po =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_b + count));
HWY_ASSERT(px && pe && po);
T* x = px.get() + misalign_a;
T* e = pe.get() + misalign_a;
T* o = po.get() + misalign_b;
for (size_t i = 0; i < count; ++i) {
x[i] = Random<T>(rng);
e[i] = x[i];
o[i] = Random<T>(rng);
}
SourceMulBy(o, e, count, count);
MulBy(o, x, count, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__);
}
};
struct TestMulByConstAndAdd {
template <class D>
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
@ -234,7 +198,7 @@ struct TestMulByConstAndAdd {
}
T constant = Random<T>(rng);
SourceMulByConstAndAdd(constant, o, e, count);
SimpleMulByConstAndAdd(constant, o, e, count);
MulByConstAndAdd(constant, o, x, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
@ -264,8 +228,8 @@ struct TestMulByConst {
}
T constant = Random<T>(rng);
SourceMulByConst(constant, e, count, count);
MulByConst(constant, x, count, count);
SimpleMulByConst(constant, e, count);
MulByConst(constant, x, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__);
@ -294,8 +258,8 @@ struct TestSoftmax {
e[i] = x[i];
}
SourceSoftmax(e, count, count);
Softmax(x, count, count);
SimpleSoftmax(e, count);
Softmax(x, count);
T sum = 0.0f;
for (size_t i = 0; i < count; ++i) {
@ -331,10 +295,6 @@ void TestAllAddFrom() {
hn::ForPartialVectors<ForeachCountAndMisalign<TestAddFrom>>()(float());
}
void TestAllMulBy() {
hn::ForPartialVectors<ForeachCountAndMisalign<TestMulBy>>()(float());
}
void TestAllMulByConst() {
hn::ForPartialVectors<ForeachCountAndMisalign<TestMulByConst>>()(float());
}
@ -371,8 +331,8 @@ void TestSigmoid() {
}
static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, int pos) {
const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, const int pos) {
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
@ -387,9 +347,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
}
void TestRopeAndMulBy() {
ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
ChooseWrapping(Model::GEMMA2_9B));
int dim_qkv = config.layer_configs[0].qkv_dim;
const size_t dim_qkv = config.layer_configs[0].qkv_dim;
MatStorageT<float> x("x", dim_qkv);
std::mt19937 gen;
@ -397,44 +357,58 @@ void TestRopeAndMulBy() {
std::normal_distribution<float> r{0.0, 5.0};
auto random_float = [&r, &gen] { return r(gen); };
for (int i = 0; i < dim_qkv; ++i) {
for (size_t i = 0; i < dim_qkv; ++i) {
x.Row(0)[i] = random_float();
}
const float qmul = AttentionActivations::ChooseQueryScale(config);
const float kmul = 1.0;
constexpr float kmul = 1.0f;
MatStorageT<float> qexpected("qexpected", dim_qkv);
MatStorageT<float> qactual("qactual", dim_qkv);
MatStorageT<float> kexpected("kexpected", dim_qkv);
MatStorageT<float> kactual("kactual", dim_qkv);
MatStorageT<float> kactual2("kactual2", dim_qkv);
MatStorageT<float> inv_timescale = CreateInvTimescale(
config.layer_configs[0].qkv_dim,
config.layer_configs[0].post_qk == PostQKType::HalfRope);
// Assert VectorizedRope computation is same as regular rope at different pos.
for (int pos = 1; pos < 500; pos++) {
// Rope'd Q embeddings
CopyMat(x, qactual);
for (size_t pos = 1; pos < 500; pos++) {
// Rope'd Q embeddings with query scale
CopyMat(x, qexpected);
CopyMat(x, qactual);
ScalarRopeAndMulBy(qmul, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
pos);
RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
for (size_t i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
}
for (int i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(qactual.Row(0)[i], qexpected.Row(0)[i], 1e-4)
<< "qIndex:" << i << "qInput:" << qactual.Row(0)[i];
// Same without query scale
CopyMat(x, qexpected);
CopyMat(x, qactual);
ScalarRopeAndMulBy(1.0f, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
pos);
Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
for (size_t i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
}
// Rope'd K embeddings
CopyMat(x, kactual);
CopyMat(x, kexpected);
CopyMat(x, kactual);
CopyMat(x, kactual2);
ScalarRopeAndMulBy(kmul, kexpected.Row(0), dim_qkv, inv_timescale.Row(0),
pos);
RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
static_assert(kmul == 1.0f, "");
Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos);
for (int i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(kactual.Row(0)[i], kexpected.Row(0)[i], 1e-4)
<< "kIndex:" << i << "kInput:" << kactual.Row(0)[i];
for (size_t i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(kexpected.Row(0)[i], kactual.Row(0)[i], 1e-4) << " " << i;
}
for (size_t i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(kexpected.Row(0)[i], kactual2.Row(0)[i], 1e-4) << " " << i;
}
}
}
@ -662,7 +636,6 @@ HWY_AFTER_NAMESPACE();
namespace gcpp {
HWY_BEFORE_TEST(OpsTest);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllAddFrom);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulBy);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);