Further nuq_test speedups to prevent timeout

PiperOrigin-RevId: 670863385
This commit is contained in:
Jan Wassenberg 2024-09-04 00:49:06 -07:00 committed by Copybara-Service
parent 9661b81c4b
commit 07c34cb18a
2 changed files with 9 additions and 6 deletions

View File

@ -116,6 +116,7 @@ cc_library(
cc_test(
name = "nuq_test",
size = "small",
timeout = "long",
srcs = ["nuq_test.cc"],
features = ["fully_static_link"],
linkstatic = True,

View File

@ -13,9 +13,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// SFP uses ConcatEven/Odd which are not supported. Use HWY_EMU128 instead.
// SFP uses ConcatEven/Odd, which HWY_SCALAR does not support; also skip SVE for faster tests.
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS HWY_SCALAR
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SVE)
#endif
#include "compression/nuq.h"
@ -49,6 +49,8 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
// Repetition count shared by the timing loops in this file, replacing the
// larger per-loop counts (20 and 40) to shorten test runtime. The value is
// passed through hn::AdjustedReps rather than used directly.
static constexpr size_t kTimingReps = hn::AdjustedReps(3);
// All-equal inputs: only one cluster
struct TestFlat {
template <typename T, class DF>
@ -201,7 +203,7 @@ struct TestNormal {
float centers[kClusters];
uint16_t indices[kGroupSize];
double elapsed = hwy::HighestValue<double>();
for (size_t rep = 0; rep < hn::AdjustedReps(40); ++rep) {
for (size_t rep = 0; rep < kTimingReps; ++rep) {
const double t0 = hwy::platform::Now();
const size_t unused_clusters = NuqClustering::ClusterExactL2(
df, in.get(), kGroupSize, buf, centers, indices);
@ -327,7 +329,7 @@ struct TestStream {
ClusterBuf buf;
double elapsed = hwy::HighestValue<double>();
for (size_t rep = 0; rep < hn::AdjustedReps(40); ++rep) {
for (size_t rep = 0; rep < kTimingReps; ++rep) {
const double t0 = hwy::platform::Now();
const size_t unused_clusters =
NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0);
@ -339,7 +341,7 @@ struct TestStream {
num * sizeof(float) * 1E-6 / elapsed);
elapsed = hwy::HighestValue<double>();
for (size_t rep = 0; rep < hn::AdjustedReps(40); ++rep) {
for (size_t rep = 0; rep < kTimingReps; ++rep) {
const double t0 = hwy::platform::Now();
NuqCodec::Dec(d, num, nuq.get(), 0, out.get(), num);
const double t1 = hwy::platform::Now();
@ -408,7 +410,7 @@ struct TestDot {
// Compute dot product without decompression.
float actual = 0.0f;
double elapsed = hwy::HighestValue<double>();
for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
for (size_t rep = 0; rep < kTimingReps; ++rep) {
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
hn::Vec<decltype(df)> sum1 = hn::Zero(df);
hn::Vec<decltype(df)> sum2 = hn::Zero(df);