mirror of https://github.com/google/gemma.cpp.git
Further nuq_test speedups to prevent timeout
PiperOrigin-RevId: 670863385
This commit is contained in:
parent
9661b81c4b
commit
07c34cb18a
|
|
@ -116,6 +116,7 @@ cc_library(
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "nuq_test",
|
name = "nuq_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
timeout = "long",
|
||||||
srcs = ["nuq_test.cc"],
|
srcs = ["nuq_test.cc"],
|
||||||
features = ["fully_static_link"],
|
features = ["fully_static_link"],
|
||||||
linkstatic = True,
|
linkstatic = True,
|
||||||
|
|
|
||||||
|
|
@ -13,9 +13,9 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
// SFP uses ConcatEven/Odd which are not supported. Use HWY_EMU128 instead.
|
// SFP uses ConcatEven/Odd, which HWY_SCALAR does not support; also skip SVE for faster tests.
|
||||||
#ifndef HWY_DISABLED_TARGETS
|
#ifndef HWY_DISABLED_TARGETS
|
||||||
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SVE)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "compression/nuq.h"
|
#include "compression/nuq.h"
|
||||||
|
|
@ -49,6 +49,8 @@ HWY_BEFORE_NAMESPACE();
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
namespace HWY_NAMESPACE {
|
namespace HWY_NAMESPACE {
|
||||||
|
|
||||||
|
static constexpr size_t kTimingReps = hn::AdjustedReps(3);
|
||||||
|
|
||||||
// All-equal inputs: only one cluster
|
// All-equal inputs: only one cluster
|
||||||
struct TestFlat {
|
struct TestFlat {
|
||||||
template <typename T, class DF>
|
template <typename T, class DF>
|
||||||
|
|
@ -201,7 +203,7 @@ struct TestNormal {
|
||||||
float centers[kClusters];
|
float centers[kClusters];
|
||||||
uint16_t indices[kGroupSize];
|
uint16_t indices[kGroupSize];
|
||||||
double elapsed = hwy::HighestValue<double>();
|
double elapsed = hwy::HighestValue<double>();
|
||||||
for (size_t rep = 0; rep < hn::AdjustedReps(40); ++rep) {
|
for (size_t rep = 0; rep < kTimingReps; ++rep) {
|
||||||
const double t0 = hwy::platform::Now();
|
const double t0 = hwy::platform::Now();
|
||||||
const size_t unused_clusters = NuqClustering::ClusterExactL2(
|
const size_t unused_clusters = NuqClustering::ClusterExactL2(
|
||||||
df, in.get(), kGroupSize, buf, centers, indices);
|
df, in.get(), kGroupSize, buf, centers, indices);
|
||||||
|
|
@ -327,7 +329,7 @@ struct TestStream {
|
||||||
|
|
||||||
ClusterBuf buf;
|
ClusterBuf buf;
|
||||||
double elapsed = hwy::HighestValue<double>();
|
double elapsed = hwy::HighestValue<double>();
|
||||||
for (size_t rep = 0; rep < hn::AdjustedReps(40); ++rep) {
|
for (size_t rep = 0; rep < kTimingReps; ++rep) {
|
||||||
const double t0 = hwy::platform::Now();
|
const double t0 = hwy::platform::Now();
|
||||||
const size_t unused_clusters =
|
const size_t unused_clusters =
|
||||||
NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0);
|
NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0);
|
||||||
|
|
@ -339,7 +341,7 @@ struct TestStream {
|
||||||
num * sizeof(float) * 1E-6 / elapsed);
|
num * sizeof(float) * 1E-6 / elapsed);
|
||||||
|
|
||||||
elapsed = hwy::HighestValue<double>();
|
elapsed = hwy::HighestValue<double>();
|
||||||
for (size_t rep = 0; rep < hn::AdjustedReps(40); ++rep) {
|
for (size_t rep = 0; rep < kTimingReps; ++rep) {
|
||||||
const double t0 = hwy::platform::Now();
|
const double t0 = hwy::platform::Now();
|
||||||
NuqCodec::Dec(d, num, nuq.get(), 0, out.get(), num);
|
NuqCodec::Dec(d, num, nuq.get(), 0, out.get(), num);
|
||||||
const double t1 = hwy::platform::Now();
|
const double t1 = hwy::platform::Now();
|
||||||
|
|
@ -408,7 +410,7 @@ struct TestDot {
|
||||||
// Compute dot product without decompression.
|
// Compute dot product without decompression.
|
||||||
float actual = 0.0f;
|
float actual = 0.0f;
|
||||||
double elapsed = hwy::HighestValue<double>();
|
double elapsed = hwy::HighestValue<double>();
|
||||||
for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
|
for (size_t rep = 0; rep < kTimingReps; ++rep) {
|
||||||
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
|
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
|
||||||
hn::Vec<decltype(df)> sum1 = hn::Zero(df);
|
hn::Vec<decltype(df)> sum1 = hn::Zero(df);
|
||||||
hn::Vec<decltype(df)> sum2 = hn::Zero(df);
|
hn::Vec<decltype(df)> sum2 = hn::Zero(df);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue