mirror of https://github.com/google/gemma.cpp.git
Fix bf16 KV recompression and Rope(), fixes #608
Also add a more helpful error message for prompt > seq_len. Also update ops_test, adding coverage for Rope(). PiperOrigin-RevId: 772945644
commit 7f62c2606e
parent 88284387db
@@ -299,19 +299,21 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
                              layer_idx * cache_layer_size +
                              head * qkv_dim * 2;

+      HWY_ALIGN float kv_f32[2 * kMaxQKVDim];
+      const hn::ScalableTag<float> df;
+      DecompressAndZeroPad(df, MakeSpan(kv, 2 * qkv_dim), 0, kv_f32,
+                           2 * qkv_dim);
+
       // Apply further processing to K.
       if (layer.key_norm_scale.HasPtr()) {
         CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
-          RMSNormInplace(weights_t->PackedScale1(), 0, kv, qkv_dim);
+          RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim);
         });
       }

-      HWY_ALIGN float kv_f32[kMaxQKVDim];
-      const hn::ScalableTag<float> df;
-      DecompressAndZeroPad(df, MakeSpan(kv, qkv_dim), 0, kv_f32, qkv_dim);
       PositionalEncodingQK(kv_f32, layer_idx, layer, activations, pos);
       CompressPerThread tls;
-      Compress(kv_f32, qkv_dim, tls, MakeSpan(kv, qkv_dim), 0);
+      Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
     });
   }
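Editor's note on the hunk above: the fix switches K/V post-processing to a decompress, modify, recompress round trip over the full per-head K|V slot (2 * qkv_dim elements), instead of normalizing the compressed kv in place and recompressing only the K half. Below is a minimal standalone sketch of that pattern, not the gemma.cpp API: ToBF16/FromBF16 are hypothetical stand-ins for the real Compress/DecompressAndZeroPad, and the [K | V] slot layout is inferred from `head * qkv_dim * 2` above.

// Minimal standalone sketch (assumptions noted above); compiles as plain C++.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Truncating float <-> bf16 conversion, illustrative only.
static uint16_t ToBF16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}
static float FromBF16(uint16_t b) {
  const uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  const size_t qkv_dim = 4;
  // Per-head KV slot: [K(qkv_dim) | V(qkv_dim)], stored compressed (bf16 here).
  std::vector<uint16_t> kv_slot(2 * qkv_dim);
  for (size_t i = 0; i < 2 * qkv_dim; ++i) {
    kv_slot[i] = ToBF16(0.25f * static_cast<float>(i + 1));
  }

  // 1) Decompress the whole K|V span to f32 (cf. DecompressAndZeroPad over
  //    MakeSpan(kv, 2 * qkv_dim) in the hunk).
  std::vector<float> kv_f32(2 * qkv_dim);
  for (size_t i = 0; i < 2 * qkv_dim; ++i) kv_f32[i] = FromBF16(kv_slot[i]);

  // 2) Apply K-only processing (norm/RoPE in the real code) to the first
  //    qkv_dim elements; V stays untouched in f32.
  for (size_t i = 0; i < qkv_dim; ++i) kv_f32[i] *= 2.0f;

  // 3) Recompress the full 2 * qkv_dim span (cf. Compress(kv_f32, 2 * qkv_dim,
  //    ...)), so V is written back consistently with K.
  for (size_t i = 0; i < 2 * qkv_dim; ++i) kv_slot[i] = ToBF16(kv_f32[i]);

  for (size_t i = 0; i < 2 * qkv_dim; ++i) {
    std::printf("%zu: %.3f\n", i, FromBF16(kv_slot[i]));
  }
  return 0;
}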
@@ -480,7 +480,10 @@ static void GenerateT(const ModelConfig& config,
     // We use a single divisor, so all sequence lengths must be the same.
     HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
   }
-  HWY_ASSERT(max_prompt_size < seq_len);
+  if (max_prompt_size >= seq_len) {
+    HWY_ABORT("max_prompt_size = %zu, increase --seq_len to at least that.",
+              max_prompt_size);
+  }
   HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);

   // Lacks a constructor to bulk-set, hence initialized by Prefill* which have
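Editor's note: the hunk above replaces a bare HWY_ASSERT with an actionable message when the prompt does not fit in the KV cache. A plain-C++ sketch of the same guard, using a hypothetical ValidatePromptFits helper with fprintf/abort in place of Highway's HWY_ABORT:

#include <cstddef>
#include <cstdio>
#include <cstdlib>

// Hypothetical helper name; mirrors the check added in GenerateT above.
void ValidatePromptFits(size_t max_prompt_size, size_t seq_len) {
  if (max_prompt_size >= seq_len) {
    std::fprintf(stderr,
                 "max_prompt_size = %zu, increase --seq_len to at least that.\n",
                 max_prompt_size);
    std::abort();
  }
}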
@@ -406,7 +406,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
 // `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
 // This overload is called if `post_qk == PostQKType::HalfRope`.
 static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
-    float* HWY_RESTRICT x, size_t dim_qkv,
+    float* HWY_RESTRICT x, const size_t dim_qkv,
     const float* HWY_RESTRICT inv_timescale, const int pos) {
   PROFILER_ZONE("ops.Rope");
   HWY_DASSERT(dim_qkv % 2 == 0);
@@ -430,13 +430,13 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
     hn::SinCos(df, vtheta, vsin_theta, vcos_theta);

     // Scale input with rotations.
-    VF vx0 = hn::LoadU(df, x + dim);
-    VF vx1 = hn::LoadU(df, x + dim + half_dim_qkv);
-    vx0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
-    vx1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+    const VF vx0 = hn::LoadU(df, x + dim);
+    const VF vx1 = hn::LoadU(df, x + dim + half_dim_qkv);
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));

-    hn::StoreU(vx0, df, x + dim);
-    hn::StoreU(vx1, df, x + dim + half_dim_qkv);
+    hn::StoreU(vout0, df, x + dim);
+    hn::StoreU(vout1, df, x + dim + half_dim_qkv);
   }

   // Vectorize computation for remaining dims - same as above, but with LoadN.
@@ -452,19 +452,19 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
     hn::SinCos(df, vtheta, vsin_theta, vcos_theta);

     // Scale input with rotations.
-    VF vx0 = hn::LoadN(df, x + dim, remaining_dims);
-    VF vx1 = hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims);
-    vx0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
-    vx1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+    const VF vx0 = hn::LoadN(df, x + dim, remaining_dims);
+    const VF vx1 = hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims);
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));

-    hn::StoreN(vx0, df, x + dim, remaining_dims);
-    hn::StoreN(vx1, df, x + dim + half_dim_qkv, remaining_dims);
+    hn::StoreN(vout0, df, x + dim, remaining_dims);
+    hn::StoreN(vout1, df, x + dim + half_dim_qkv, remaining_dims);
   }
 }

 // `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
 static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
-    const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
+    const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
     const float* HWY_RESTRICT inv_timescale, const int pos) {
   PROFILER_ZONE("ops.RopeAndMulBy");
   HWY_DASSERT(dim_qkv % 2 == 0);
@@ -489,13 +489,13 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
     hn::SinCos(df, vtheta, vsin_theta, vcos_theta);

     // Scale input with rotations and multiply with constant.
-    VF vx0 = hn::Mul(vmul, hn::LoadU(df, x + dim));
-    VF vx1 = hn::Mul(vmul, hn::LoadU(df, x + dim + half_dim_qkv));
-    vx0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
-    vx1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+    const VF vx0 = hn::Mul(vmul, hn::LoadU(df, x + dim));
+    const VF vx1 = hn::Mul(vmul, hn::LoadU(df, x + dim + half_dim_qkv));
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));

-    hn::StoreU(vx0, df, x + dim);
-    hn::StoreU(vx1, df, x + dim + half_dim_qkv);
+    hn::StoreU(vout0, df, x + dim);
+    hn::StoreU(vout1, df, x + dim + half_dim_qkv);
   }

   // Vectorize computation for remaining dims - same as above, but with LoadN.
@@ -511,14 +511,14 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
     hn::SinCos(df, vtheta, vsin_theta, vcos_theta);

     // Scale input with rotations and multiply with constant.
-    VF vx0 = hn::Mul(vmul, hn::LoadN(df, x + dim, remaining_dims));
-    VF vx1 =
+    const VF vx0 = hn::Mul(vmul, hn::LoadN(df, x + dim, remaining_dims));
+    const VF vx1 =
         hn::Mul(vmul, hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims));
-    vx0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
-    vx1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));

-    hn::StoreN(vx0, df, x + dim, remaining_dims);
-    hn::StoreN(vx1, df, x + dim + half_dim_qkv, remaining_dims);
+    hn::StoreN(vout0, df, x + dim, remaining_dims);
+    hn::StoreN(vout1, df, x + dim + half_dim_qkv, remaining_dims);
   }
 }

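Editor's note on the Rope/RopeAndMulBy hunks: the old code stored the rotated value back into vx0 and then used that already-rotated vx0 to compute vx1; the new vout0/vout1 temporaries keep both outputs derived from the original inputs. A scalar sketch of the intended rotation follows (ScalarRope is a hypothetical name here; the test's reference implementation is ScalarRopeAndMulBy):

#include <cmath>
#include <cstddef>

// Scalar RoPE sketch: both outputs are formed from the *original* (x0, x1)
// pair before either slot is overwritten, which is exactly what the
// vout0/vout1 fix guarantees in the vectorized kernels.
void ScalarRope(float* x, size_t dim_qkv, const float* inv_timescale, int pos) {
  const size_t half = dim_qkv / 2;
  for (size_t dim = 0; dim < half; ++dim) {
    const float theta = static_cast<float>(pos) * inv_timescale[dim];
    const float x0 = x[dim];          // keep originals in temporaries
    const float x1 = x[dim + half];
    x[dim]        = x0 * std::cos(theta) - x1 * std::sin(theta);
    x[dim + half] = x0 * std::sin(theta) + x1 * std::cos(theta);
  }
}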
ops/ops_test.cc (125 changed lines)
@@ -83,48 +83,44 @@ T Random(hwy::RandomState& rng) {
                  HWY_MAX(hwy::ConvertScalarTo<double>(hwy::LowestValue<T>()), val));
 }

-HWY_NOINLINE void SourceAddFrom(const float* HWY_RESTRICT other,
+HWY_NOINLINE void SimpleAddFrom(const float* HWY_RESTRICT other,
                                 float* HWY_RESTRICT x, size_t size) {
   for (size_t i = 0; i < size; ++i) {
     x[i] += other[i];
   }
 }

-HWY_NOINLINE void SourceMulBy(const float* HWY_RESTRICT other,
-                              float* HWY_RESTRICT x, size_t size,
-                              size_t max_pos) {
-  HWY_DASSERT(max_pos <= size);
-  for (size_t i = 0; i < max_pos; ++i) {
+HWY_NOINLINE void SimpleMulBy(const float* HWY_RESTRICT other,
+                              float* HWY_RESTRICT x, size_t size) {
+  for (size_t i = 0; i < size; ++i) {
     x[i] *= other[i];
   }
 }

-HWY_NOINLINE void SourceMulByConst(float c, float* HWY_RESTRICT x, size_t size,
-                                   size_t max_pos) {
-  for (size_t i = 0; i < max_pos; ++i) {
+HWY_NOINLINE void SimpleMulByConst(float c, float* HWY_RESTRICT x,
+                                   size_t size) {
+  for (size_t i = 0; i < size; ++i) {
     x[i] *= c;
   }
 }

-HWY_NOINLINE void SourceMulByConstAndAdd(float c, const float* HWY_RESTRICT x,
+HWY_NOINLINE void SimpleMulByConstAndAdd(float c, const float* HWY_RESTRICT x,
                                          float* HWY_RESTRICT out, size_t size) {
   for (size_t i = 0; i < size; ++i) {
     out[i] += x[i] * c;
   }
 }

-HWY_NOINLINE void SourceSoftmax(float* HWY_RESTRICT x, size_t size,
-                                size_t mask_pos) {
+HWY_NOINLINE void SimpleSoftmax(float* HWY_RESTRICT x, size_t size) {
   HWY_DASSERT(size != 0);
-  HWY_DASSERT(mask_pos <= size);
   float sum = 0.0;
-  const float maxval = *std::max_element(x, x + mask_pos);
-  for (size_t i = 0; i < mask_pos; ++i) {
+  const float maxval = *std::max_element(x, x + size);
+  for (size_t i = 0; i < size; ++i) {
     x[i] = std::exp(x[i] - maxval);
     sum += x[i];
   }
   const float scale = 1.0f / sum;
-  for (size_t i = 0; i < mask_pos; ++i) {
+  for (size_t i = 0; i < size; ++i) {
     x[i] *= scale;
   }
 }

@@ -169,7 +165,7 @@ struct TestAddFrom {
       o[i] = Random<T>(rng);
     }

-    SourceAddFrom(o, e, count);
+    SimpleAddFrom(o, e, count);
     AddFrom(o, x, count);

     hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
@@ -177,38 +173,6 @@ struct TestAddFrom {
   }
 };

-struct TestMulBy {
-  template <class D>
-  void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
-                  hwy::RandomState& rng) {
-    using T = hn::TFromD<D>;
-
-    hwy::AlignedFreeUniquePtr<T[]> px =
-        hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
-    hwy::AlignedFreeUniquePtr<T[]> pe =
-        hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
-    hwy::AlignedFreeUniquePtr<T[]> po =
-        hwy::AllocateAligned<T>(HWY_MAX(1, misalign_b + count));
-    HWY_ASSERT(px && pe && po);
-
-    T* x = px.get() + misalign_a;
-    T* e = pe.get() + misalign_a;
-    T* o = po.get() + misalign_b;
-
-    for (size_t i = 0; i < count; ++i) {
-      x[i] = Random<T>(rng);
-      e[i] = x[i];
-      o[i] = Random<T>(rng);
-    }
-
-    SourceMulBy(o, e, count, count);
-    MulBy(o, x, count, count);
-
-    hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
-                            __LINE__);
-  }
-};
-
 struct TestMulByConstAndAdd {
   template <class D>
   void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
@@ -234,7 +198,7 @@ struct TestMulByConstAndAdd {
     }
     T constant = Random<T>(rng);

-    SourceMulByConstAndAdd(constant, o, e, count);
+    SimpleMulByConstAndAdd(constant, o, e, count);
     MulByConstAndAdd(constant, o, x, count);

     hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
@@ -264,8 +228,8 @@ struct TestMulByConst {
     }
     T constant = Random<T>(rng);

-    SourceMulByConst(constant, e, count, count);
-    MulByConst(constant, x, count, count);
+    SimpleMulByConst(constant, e, count);
+    MulByConst(constant, x, count);

     hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
                             __LINE__);
@@ -294,8 +258,8 @@ struct TestSoftmax {
       e[i] = x[i];
     }

-    SourceSoftmax(e, count, count);
-    Softmax(x, count, count);
+    SimpleSoftmax(e, count);
+    Softmax(x, count);

     T sum = 0.0f;
     for (size_t i = 0; i < count; ++i) {
@@ -331,10 +295,6 @@ void TestAllAddFrom() {
   hn::ForPartialVectors<ForeachCountAndMisalign<TestAddFrom>>()(float());
 }

-void TestAllMulBy() {
-  hn::ForPartialVectors<ForeachCountAndMisalign<TestMulBy>>()(float());
-}
-
 void TestAllMulByConst() {
   hn::ForPartialVectors<ForeachCountAndMisalign<TestMulByConst>>()(float());
 }
@@ -371,8 +331,8 @@ void TestSigmoid() {
 }

 static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
-    const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
-    const float* HWY_RESTRICT inv_timescale, int pos) {
+    const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
+    const float* HWY_RESTRICT inv_timescale, const int pos) {
   HWY_DASSERT(dim_qkv % 2 == 0);
   const size_t half_dim_qkv = dim_qkv / 2;
   for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
@@ -387,9 +347,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
 }

 void TestRopeAndMulBy() {
-  ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
+  const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
                      ChooseWrapping(Model::GEMMA2_9B));
-  int dim_qkv = config.layer_configs[0].qkv_dim;
+  const size_t dim_qkv = config.layer_configs[0].qkv_dim;
   MatStorageT<float> x("x", dim_qkv);

   std::mt19937 gen;
@@ -397,44 +357,58 @@ void TestRopeAndMulBy() {
   std::normal_distribution<float> r{0.0, 5.0};
   auto random_float = [&r, &gen] { return r(gen); };

-  for (int i = 0; i < dim_qkv; ++i) {
+  for (size_t i = 0; i < dim_qkv; ++i) {
     x.Row(0)[i] = random_float();
   }

   const float qmul = AttentionActivations::ChooseQueryScale(config);
-  const float kmul = 1.0;
+  constexpr float kmul = 1.0f;

   MatStorageT<float> qexpected("qexpected", dim_qkv);
   MatStorageT<float> qactual("qactual", dim_qkv);
   MatStorageT<float> kexpected("kexpected", dim_qkv);
   MatStorageT<float> kactual("kactual", dim_qkv);
+  MatStorageT<float> kactual2("kactual2", dim_qkv);
   MatStorageT<float> inv_timescale = CreateInvTimescale(
       config.layer_configs[0].qkv_dim,
       config.layer_configs[0].post_qk == PostQKType::HalfRope);
   // Assert VectorizedRope computation is same as regular rope at different pos.
-  for (int pos = 1; pos < 500; pos++) {
-    // Rope'd Q embeddings
-    CopyMat(x, qactual);
+  for (size_t pos = 1; pos < 500; pos++) {
+    // Rope'd Q embeddings with query scale
     CopyMat(x, qexpected);
+    CopyMat(x, qactual);
     ScalarRopeAndMulBy(qmul, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
                        pos);
     RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
+    for (size_t i = 0; i < dim_qkv; ++i) {
+      EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
+    }

-    for (int i = 0; i < dim_qkv; ++i) {
-      EXPECT_NEAR(qactual.Row(0)[i], qexpected.Row(0)[i], 1e-4)
-          << "qIndex:" << i << "qInput:" << qactual.Row(0)[i];
+    // Same without query scale
+    CopyMat(x, qexpected);
+    CopyMat(x, qactual);
+    ScalarRopeAndMulBy(1.0f, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
+                       pos);
+    Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
+    for (size_t i = 0; i < dim_qkv; ++i) {
+      EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
     }

     // Rope'd K embeddings
-    CopyMat(x, kactual);
     CopyMat(x, kexpected);
+    CopyMat(x, kactual);
+    CopyMat(x, kactual2);
     ScalarRopeAndMulBy(kmul, kexpected.Row(0), dim_qkv, inv_timescale.Row(0),
                        pos);
     RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
+    static_assert(kmul == 1.0f, "");
+    Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos);

-    for (int i = 0; i < dim_qkv; ++i) {
-      EXPECT_NEAR(kactual.Row(0)[i], kexpected.Row(0)[i], 1e-4)
-          << "kIndex:" << i << "kInput:" << kactual.Row(0)[i];
+    for (size_t i = 0; i < dim_qkv; ++i) {
+      EXPECT_NEAR(kexpected.Row(0)[i], kactual.Row(0)[i], 1e-4) << " " << i;
+    }
+    for (size_t i = 0; i < dim_qkv; ++i) {
+      EXPECT_NEAR(kexpected.Row(0)[i], kactual2.Row(0)[i], 1e-4) << " " << i;
     }
   }
 }
@@ -662,7 +636,6 @@ HWY_AFTER_NAMESPACE();
 namespace gcpp {
 HWY_BEFORE_TEST(OpsTest);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllAddFrom);
-HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulBy);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);