mirror of https://github.com/google/gemma.cpp.git
Further nuq_test speedups to prevent timeout
PiperOrigin-RevId: 670863385
This commit is contained in:
parent
9661b81c4b
commit
07c34cb18a
|
|
@ -116,6 +116,7 @@ cc_library(
|
|||
cc_test(
|
||||
name = "nuq_test",
|
||||
size = "small",
|
||||
timeout = "long",
|
||||
srcs = ["nuq_test.cc"],
|
||||
features = ["fully_static_link"],
|
||||
linkstatic = True,
|
||||
|
|
|
|||
|
|
@ -13,9 +13,9 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// SFP uses ConcatEven/Odd which are not supported. Use HWY_EMU128 instead.
|
||||
// SFP uses ConcatEven/Odd which are not supported; skip SVE for faster tests.
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
||||
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SVE)
|
||||
#endif
|
||||
|
||||
#include "compression/nuq.h"
|
||||
|
|
@ -49,6 +49,8 @@ HWY_BEFORE_NAMESPACE();
|
|||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
static constexpr size_t kTimingReps = hn::AdjustedReps(3);
|
||||
|
||||
// All-equal inputs: only one cluster
|
||||
struct TestFlat {
|
||||
template <typename T, class DF>
|
||||
|
|
@ -201,7 +203,7 @@ struct TestNormal {
|
|||
float centers[kClusters];
|
||||
uint16_t indices[kGroupSize];
|
||||
double elapsed = hwy::HighestValue<double>();
|
||||
for (size_t rep = 0; rep < hn::AdjustedReps(40); ++rep) {
|
||||
for (size_t rep = 0; rep < kTimingReps; ++rep) {
|
||||
const double t0 = hwy::platform::Now();
|
||||
const size_t unused_clusters = NuqClustering::ClusterExactL2(
|
||||
df, in.get(), kGroupSize, buf, centers, indices);
|
||||
|
|
@ -327,7 +329,7 @@ struct TestStream {
|
|||
|
||||
ClusterBuf buf;
|
||||
double elapsed = hwy::HighestValue<double>();
|
||||
for (size_t rep = 0; rep < hn::AdjustedReps(40); ++rep) {
|
||||
for (size_t rep = 0; rep < kTimingReps; ++rep) {
|
||||
const double t0 = hwy::platform::Now();
|
||||
const size_t unused_clusters =
|
||||
NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0);
|
||||
|
|
@ -339,7 +341,7 @@ struct TestStream {
|
|||
num * sizeof(float) * 1E-6 / elapsed);
|
||||
|
||||
elapsed = hwy::HighestValue<double>();
|
||||
for (size_t rep = 0; rep < hn::AdjustedReps(40); ++rep) {
|
||||
for (size_t rep = 0; rep < kTimingReps; ++rep) {
|
||||
const double t0 = hwy::platform::Now();
|
||||
NuqCodec::Dec(d, num, nuq.get(), 0, out.get(), num);
|
||||
const double t1 = hwy::platform::Now();
|
||||
|
|
@ -408,7 +410,7 @@ struct TestDot {
|
|||
// Compute dot product without decompression.
|
||||
float actual = 0.0f;
|
||||
double elapsed = hwy::HighestValue<double>();
|
||||
for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
|
||||
for (size_t rep = 0; rep < kTimingReps; ++rep) {
|
||||
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
|
||||
hn::Vec<decltype(df)> sum1 = hn::Zero(df);
|
||||
hn::Vec<decltype(df)> sum2 = hn::Zero(df);
|
||||
|
|
|
|||
Loading…
Reference in New Issue