mirror of https://github.com/google/gemma.cpp.git
Fix bf16 KV recompression and Rope(), fixes #608
Also add more helpful error message for prompt > seq_len Also update ops_test, adding coverage for Rope(). PiperOrigin-RevId: 772945644
This commit is contained in:
parent
88284387db
commit
7f62c2606e
|
|
@ -299,19 +299,21 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
|||
layer_idx * cache_layer_size +
|
||||
head * qkv_dim * 2;
|
||||
|
||||
HWY_ALIGN float kv_f32[2 * kMaxQKVDim];
|
||||
const hn::ScalableTag<float> df;
|
||||
DecompressAndZeroPad(df, MakeSpan(kv, 2 * qkv_dim), 0, kv_f32,
|
||||
2 * qkv_dim);
|
||||
|
||||
// Apply further processing to K.
|
||||
if (layer.key_norm_scale.HasPtr()) {
|
||||
CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
|
||||
RMSNormInplace(weights_t->PackedScale1(), 0, kv, qkv_dim);
|
||||
RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim);
|
||||
});
|
||||
}
|
||||
|
||||
HWY_ALIGN float kv_f32[kMaxQKVDim];
|
||||
const hn::ScalableTag<float> df;
|
||||
DecompressAndZeroPad(df, MakeSpan(kv, qkv_dim), 0, kv_f32, qkv_dim);
|
||||
PositionalEncodingQK(kv_f32, layer_idx, layer, activations, pos);
|
||||
CompressPerThread tls;
|
||||
Compress(kv_f32, qkv_dim, tls, MakeSpan(kv, qkv_dim), 0);
|
||||
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -480,7 +480,10 @@ static void GenerateT(const ModelConfig& config,
|
|||
// We use a single divisor, so all sequence lengths must be the same.
|
||||
HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
|
||||
}
|
||||
HWY_ASSERT(max_prompt_size < seq_len);
|
||||
if (max_prompt_size >= seq_len) {
|
||||
HWY_ABORT("max_prompt_size = %zu, increase --seq_len to at least that.",
|
||||
max_prompt_size);
|
||||
}
|
||||
HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);
|
||||
|
||||
// Lacks a constructor to bulk-set, hence initialized by Prefill* which have
|
||||
|
|
|
|||
|
|
@ -406,7 +406,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
|
|||
// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
|
||||
// This overload is called if `post_qk == PostQKType::HalfRope`.
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
|
||||
float* HWY_RESTRICT x, size_t dim_qkv,
|
||||
float* HWY_RESTRICT x, const size_t dim_qkv,
|
||||
const float* HWY_RESTRICT inv_timescale, const int pos) {
|
||||
PROFILER_ZONE("ops.Rope");
|
||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||
|
|
@ -430,13 +430,13 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
|
|||
hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
|
||||
|
||||
// Scale input with rotations.
|
||||
VF vx0 = hn::LoadU(df, x + dim);
|
||||
VF vx1 = hn::LoadU(df, x + dim + half_dim_qkv);
|
||||
vx0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
|
||||
vx1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
|
||||
const VF vx0 = hn::LoadU(df, x + dim);
|
||||
const VF vx1 = hn::LoadU(df, x + dim + half_dim_qkv);
|
||||
const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
|
||||
const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
|
||||
|
||||
hn::StoreU(vx0, df, x + dim);
|
||||
hn::StoreU(vx1, df, x + dim + half_dim_qkv);
|
||||
hn::StoreU(vout0, df, x + dim);
|
||||
hn::StoreU(vout1, df, x + dim + half_dim_qkv);
|
||||
}
|
||||
|
||||
// Vectorize computation for remaining dims - same as above, but with LoadN.
|
||||
|
|
@ -452,19 +452,19 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
|
|||
hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
|
||||
|
||||
// Scale input with rotations.
|
||||
VF vx0 = hn::LoadN(df, x + dim, remaining_dims);
|
||||
VF vx1 = hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims);
|
||||
vx0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
|
||||
vx1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
|
||||
const VF vx0 = hn::LoadN(df, x + dim, remaining_dims);
|
||||
const VF vx1 = hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims);
|
||||
const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
|
||||
const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
|
||||
|
||||
hn::StoreN(vx0, df, x + dim, remaining_dims);
|
||||
hn::StoreN(vx1, df, x + dim + half_dim_qkv, remaining_dims);
|
||||
hn::StoreN(vout0, df, x + dim, remaining_dims);
|
||||
hn::StoreN(vout1, df, x + dim + half_dim_qkv, remaining_dims);
|
||||
}
|
||||
}
|
||||
|
||||
// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
|
||||
const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
|
||||
const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
|
||||
const float* HWY_RESTRICT inv_timescale, const int pos) {
|
||||
PROFILER_ZONE("ops.RopeAndMulBy");
|
||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||
|
|
@ -489,13 +489,13 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
|
|||
hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
|
||||
|
||||
// Scale input with rotations and multiply with constant.
|
||||
VF vx0 = hn::Mul(vmul, hn::LoadU(df, x + dim));
|
||||
VF vx1 = hn::Mul(vmul, hn::LoadU(df, x + dim + half_dim_qkv));
|
||||
vx0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
|
||||
vx1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
|
||||
const VF vx0 = hn::Mul(vmul, hn::LoadU(df, x + dim));
|
||||
const VF vx1 = hn::Mul(vmul, hn::LoadU(df, x + dim + half_dim_qkv));
|
||||
const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
|
||||
const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
|
||||
|
||||
hn::StoreU(vx0, df, x + dim);
|
||||
hn::StoreU(vx1, df, x + dim + half_dim_qkv);
|
||||
hn::StoreU(vout0, df, x + dim);
|
||||
hn::StoreU(vout1, df, x + dim + half_dim_qkv);
|
||||
}
|
||||
|
||||
// Vectorize computation for remaining dims - same as above, but with LoadN.
|
||||
|
|
@ -511,14 +511,14 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
|
|||
hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
|
||||
|
||||
// Scale input with rotations and multiply with constant.
|
||||
VF vx0 = hn::Mul(vmul, hn::LoadN(df, x + dim, remaining_dims));
|
||||
VF vx1 =
|
||||
const VF vx0 = hn::Mul(vmul, hn::LoadN(df, x + dim, remaining_dims));
|
||||
const VF vx1 =
|
||||
hn::Mul(vmul, hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims));
|
||||
vx0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
|
||||
vx1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
|
||||
const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
|
||||
const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
|
||||
|
||||
hn::StoreN(vx0, df, x + dim, remaining_dims);
|
||||
hn::StoreN(vx1, df, x + dim + half_dim_qkv, remaining_dims);
|
||||
hn::StoreN(vout0, df, x + dim, remaining_dims);
|
||||
hn::StoreN(vout1, df, x + dim + half_dim_qkv, remaining_dims);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
125
ops/ops_test.cc
125
ops/ops_test.cc
|
|
@ -83,48 +83,44 @@ T Random(hwy::RandomState& rng) {
|
|||
HWY_MAX(hwy::ConvertScalarTo<double>(hwy::LowestValue<T>()), val));
|
||||
}
|
||||
|
||||
HWY_NOINLINE void SourceAddFrom(const float* HWY_RESTRICT other,
|
||||
HWY_NOINLINE void SimpleAddFrom(const float* HWY_RESTRICT other,
|
||||
float* HWY_RESTRICT x, size_t size) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
x[i] += other[i];
|
||||
}
|
||||
}
|
||||
|
||||
HWY_NOINLINE void SourceMulBy(const float* HWY_RESTRICT other,
|
||||
float* HWY_RESTRICT x, size_t size,
|
||||
size_t max_pos) {
|
||||
HWY_DASSERT(max_pos <= size);
|
||||
for (size_t i = 0; i < max_pos; ++i) {
|
||||
HWY_NOINLINE void SimpleMulBy(const float* HWY_RESTRICT other,
|
||||
float* HWY_RESTRICT x, size_t size) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
x[i] *= other[i];
|
||||
}
|
||||
}
|
||||
|
||||
HWY_NOINLINE void SourceMulByConst(float c, float* HWY_RESTRICT x, size_t size,
|
||||
size_t max_pos) {
|
||||
for (size_t i = 0; i < max_pos; ++i) {
|
||||
HWY_NOINLINE void SimpleMulByConst(float c, float* HWY_RESTRICT x,
|
||||
size_t size) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
x[i] *= c;
|
||||
}
|
||||
}
|
||||
|
||||
HWY_NOINLINE void SourceMulByConstAndAdd(float c, const float* HWY_RESTRICT x,
|
||||
HWY_NOINLINE void SimpleMulByConstAndAdd(float c, const float* HWY_RESTRICT x,
|
||||
float* HWY_RESTRICT out, size_t size) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
out[i] += x[i] * c;
|
||||
}
|
||||
}
|
||||
|
||||
HWY_NOINLINE void SourceSoftmax(float* HWY_RESTRICT x, size_t size,
|
||||
size_t mask_pos) {
|
||||
HWY_NOINLINE void SimpleSoftmax(float* HWY_RESTRICT x, size_t size) {
|
||||
HWY_DASSERT(size != 0);
|
||||
HWY_DASSERT(mask_pos <= size);
|
||||
float sum = 0.0;
|
||||
const float maxval = *std::max_element(x, x + mask_pos);
|
||||
for (size_t i = 0; i < mask_pos; ++i) {
|
||||
const float maxval = *std::max_element(x, x + size);
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
x[i] = std::exp(x[i] - maxval);
|
||||
sum += x[i];
|
||||
}
|
||||
const float scale = 1.0f / sum;
|
||||
for (size_t i = 0; i < mask_pos; ++i) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
x[i] *= scale;
|
||||
}
|
||||
}
|
||||
|
|
@ -169,7 +165,7 @@ struct TestAddFrom {
|
|||
o[i] = Random<T>(rng);
|
||||
}
|
||||
|
||||
SourceAddFrom(o, e, count);
|
||||
SimpleAddFrom(o, e, count);
|
||||
AddFrom(o, x, count);
|
||||
|
||||
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
||||
|
|
@ -177,38 +173,6 @@ struct TestAddFrom {
|
|||
}
|
||||
};
|
||||
|
||||
struct TestMulBy {
|
||||
template <class D>
|
||||
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
||||
hwy::RandomState& rng) {
|
||||
using T = hn::TFromD<D>;
|
||||
|
||||
hwy::AlignedFreeUniquePtr<T[]> px =
|
||||
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
||||
hwy::AlignedFreeUniquePtr<T[]> pe =
|
||||
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
||||
hwy::AlignedFreeUniquePtr<T[]> po =
|
||||
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_b + count));
|
||||
HWY_ASSERT(px && pe && po);
|
||||
|
||||
T* x = px.get() + misalign_a;
|
||||
T* e = pe.get() + misalign_a;
|
||||
T* o = po.get() + misalign_b;
|
||||
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
x[i] = Random<T>(rng);
|
||||
e[i] = x[i];
|
||||
o[i] = Random<T>(rng);
|
||||
}
|
||||
|
||||
SourceMulBy(o, e, count, count);
|
||||
MulBy(o, x, count, count);
|
||||
|
||||
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
||||
__LINE__);
|
||||
}
|
||||
};
|
||||
|
||||
struct TestMulByConstAndAdd {
|
||||
template <class D>
|
||||
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
||||
|
|
@ -234,7 +198,7 @@ struct TestMulByConstAndAdd {
|
|||
}
|
||||
T constant = Random<T>(rng);
|
||||
|
||||
SourceMulByConstAndAdd(constant, o, e, count);
|
||||
SimpleMulByConstAndAdd(constant, o, e, count);
|
||||
MulByConstAndAdd(constant, o, x, count);
|
||||
|
||||
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
||||
|
|
@ -264,8 +228,8 @@ struct TestMulByConst {
|
|||
}
|
||||
T constant = Random<T>(rng);
|
||||
|
||||
SourceMulByConst(constant, e, count, count);
|
||||
MulByConst(constant, x, count, count);
|
||||
SimpleMulByConst(constant, e, count);
|
||||
MulByConst(constant, x, count);
|
||||
|
||||
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
||||
__LINE__);
|
||||
|
|
@ -294,8 +258,8 @@ struct TestSoftmax {
|
|||
e[i] = x[i];
|
||||
}
|
||||
|
||||
SourceSoftmax(e, count, count);
|
||||
Softmax(x, count, count);
|
||||
SimpleSoftmax(e, count);
|
||||
Softmax(x, count);
|
||||
|
||||
T sum = 0.0f;
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
|
|
@ -331,10 +295,6 @@ void TestAllAddFrom() {
|
|||
hn::ForPartialVectors<ForeachCountAndMisalign<TestAddFrom>>()(float());
|
||||
}
|
||||
|
||||
void TestAllMulBy() {
|
||||
hn::ForPartialVectors<ForeachCountAndMisalign<TestMulBy>>()(float());
|
||||
}
|
||||
|
||||
void TestAllMulByConst() {
|
||||
hn::ForPartialVectors<ForeachCountAndMisalign<TestMulByConst>>()(float());
|
||||
}
|
||||
|
|
@ -371,8 +331,8 @@ void TestSigmoid() {
|
|||
}
|
||||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
|
||||
const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
|
||||
const float* HWY_RESTRICT inv_timescale, int pos) {
|
||||
const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
|
||||
const float* HWY_RESTRICT inv_timescale, const int pos) {
|
||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||
const size_t half_dim_qkv = dim_qkv / 2;
|
||||
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
|
||||
|
|
@ -387,9 +347,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
|
|||
}
|
||||
|
||||
void TestRopeAndMulBy() {
|
||||
ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
|
||||
ChooseWrapping(Model::GEMMA2_9B));
|
||||
int dim_qkv = config.layer_configs[0].qkv_dim;
|
||||
const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
|
||||
ChooseWrapping(Model::GEMMA2_9B));
|
||||
const size_t dim_qkv = config.layer_configs[0].qkv_dim;
|
||||
MatStorageT<float> x("x", dim_qkv);
|
||||
|
||||
std::mt19937 gen;
|
||||
|
|
@ -397,44 +357,58 @@ void TestRopeAndMulBy() {
|
|||
std::normal_distribution<float> r{0.0, 5.0};
|
||||
auto random_float = [&r, &gen] { return r(gen); };
|
||||
|
||||
for (int i = 0; i < dim_qkv; ++i) {
|
||||
for (size_t i = 0; i < dim_qkv; ++i) {
|
||||
x.Row(0)[i] = random_float();
|
||||
}
|
||||
|
||||
const float qmul = AttentionActivations::ChooseQueryScale(config);
|
||||
const float kmul = 1.0;
|
||||
constexpr float kmul = 1.0f;
|
||||
|
||||
MatStorageT<float> qexpected("qexpected", dim_qkv);
|
||||
MatStorageT<float> qactual("qactual", dim_qkv);
|
||||
MatStorageT<float> kexpected("kexpected", dim_qkv);
|
||||
MatStorageT<float> kactual("kactual", dim_qkv);
|
||||
MatStorageT<float> kactual2("kactual2", dim_qkv);
|
||||
MatStorageT<float> inv_timescale = CreateInvTimescale(
|
||||
config.layer_configs[0].qkv_dim,
|
||||
config.layer_configs[0].post_qk == PostQKType::HalfRope);
|
||||
// Assert VectorizedRope computation is same as regular rope at different pos.
|
||||
for (int pos = 1; pos < 500; pos++) {
|
||||
// Rope'd Q embeddings
|
||||
CopyMat(x, qactual);
|
||||
for (size_t pos = 1; pos < 500; pos++) {
|
||||
// Rope'd Q embeddings with query scale
|
||||
CopyMat(x, qexpected);
|
||||
CopyMat(x, qactual);
|
||||
ScalarRopeAndMulBy(qmul, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
|
||||
pos);
|
||||
RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
|
||||
for (size_t i = 0; i < dim_qkv; ++i) {
|
||||
EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
|
||||
}
|
||||
|
||||
for (int i = 0; i < dim_qkv; ++i) {
|
||||
EXPECT_NEAR(qactual.Row(0)[i], qexpected.Row(0)[i], 1e-4)
|
||||
<< "qIndex:" << i << "qInput:" << qactual.Row(0)[i];
|
||||
// Same without query scale
|
||||
CopyMat(x, qexpected);
|
||||
CopyMat(x, qactual);
|
||||
ScalarRopeAndMulBy(1.0f, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
|
||||
pos);
|
||||
Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
|
||||
for (size_t i = 0; i < dim_qkv; ++i) {
|
||||
EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
|
||||
}
|
||||
|
||||
// Rope'd K embeddings
|
||||
CopyMat(x, kactual);
|
||||
CopyMat(x, kexpected);
|
||||
CopyMat(x, kactual);
|
||||
CopyMat(x, kactual2);
|
||||
ScalarRopeAndMulBy(kmul, kexpected.Row(0), dim_qkv, inv_timescale.Row(0),
|
||||
pos);
|
||||
RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
|
||||
static_assert(kmul == 1.0f, "");
|
||||
Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos);
|
||||
|
||||
for (int i = 0; i < dim_qkv; ++i) {
|
||||
EXPECT_NEAR(kactual.Row(0)[i], kexpected.Row(0)[i], 1e-4)
|
||||
<< "kIndex:" << i << "kInput:" << kactual.Row(0)[i];
|
||||
for (size_t i = 0; i < dim_qkv; ++i) {
|
||||
EXPECT_NEAR(kexpected.Row(0)[i], kactual.Row(0)[i], 1e-4) << " " << i;
|
||||
}
|
||||
for (size_t i = 0; i < dim_qkv; ++i) {
|
||||
EXPECT_NEAR(kexpected.Row(0)[i], kactual2.Row(0)[i], 1e-4) << " " << i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -662,7 +636,6 @@ HWY_AFTER_NAMESPACE();
|
|||
namespace gcpp {
|
||||
HWY_BEFORE_TEST(OpsTest);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllAddFrom);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulBy);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
|
||||
|
|
|
|||
Loading…
Reference in New Issue