From afd82376a5c11b70bd852ceb2c4b41ca39d24518 Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Thu, 4 Sep 2025 05:58:08 -0700
Subject: [PATCH] Add AES-CTR RNG for parallel sampling (not yet used)

PiperOrigin-RevId: 802991142
---
 BUILD.bazel         |  15 ++++++
 CMakeLists.txt      |   2 +
 util/basics.cc      |  75 ++++++++++++++++++++++++++++++
 util/basics.h       |  36 +++++++++++++++
 util/basics_test.cc | 108 ++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 236 insertions(+)
 create mode 100644 util/basics.cc
 create mode 100644 util/basics_test.cc

diff --git a/BUILD.bazel b/BUILD.bazel
index ce4cffb..62f2f5c 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -29,9 +29,24 @@ exports_files([
 
 cc_library(
     name = "basics",
+    srcs = ["util/basics.cc"],
     hdrs = ["util/basics.h"],
     deps = [
         "@highway//:hwy",
+        "@highway//:timer",
+        "@highway//hwy/contrib/sort:vqsort",
+    ],
+)
+
+cc_test(
+    name = "basics_test",
+    srcs = ["util/basics_test.cc"],
+    deps = [
+        ":basics",
+        "@googletest//:gtest_main",  # buildcleaner: keep
+        "@highway//:hwy",
+        "@highway//:hwy_test_util",
+        "@highway//:timer",
     ],
 )
 
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8309840..d3a66fd 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -120,6 +120,7 @@ set(SOURCES
   paligemma/image.h
   util/allocator.cc
   util/allocator.h
+  util/basics.cc
   util/basics.h
   util/mat.cc
   util/mat.h
@@ -227,6 +228,7 @@ set(GEMMA_TEST_FILES
   ops/ops_test.cc
   paligemma/image_test.cc
   paligemma/paligemma_test.cc
+  util/basics_test.cc
   util/threading_test.cc
 )
 
diff --git a/util/basics.cc b/util/basics.cc
new file mode 100644
index 0000000..4261510
--- /dev/null
+++ b/util/basics.cc
@@ -0,0 +1,75 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "util/basics.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "hwy/contrib/sort/vqsort.h"
+#include "hwy/highway.h"
+#include "hwy/timer.h"
+
+namespace gcpp {
+
+RNG::RNG(bool deterministic) {
+  // Pi-based nothing up my sleeve numbers from Randen.
+  key_[0] = 0x243F6A8885A308D3ull;
+  key_[1] = 0x13198A2E03707344ull;
+
+  if (!deterministic) {  // want random seed
+    if (!hwy::Fill16BytesSecure(key_)) {
+      HWY_WARN("Failed to fill RNG key with secure random bits");
+      // Entropy not available. The test requires that we inject some
+      // differences relative to the deterministic seeds.
+      key_[0] ^= reinterpret_cast<uintptr_t>(this);
+      key_[1] ^= hwy::timer::Start();
+    }
+  }
+
+  // Simple key schedule: swap and add constant (also from Randen).
+  for (size_t i = 0; i < kRounds; ++i) {
+    key_[2 + 2 * i + 0] = key_[2 * i + 1] + 0xA4093822299F31D0ull;
+    key_[2 + 2 * i + 1] = key_[2 * i + 0] + 0x082EFA98EC4E6C89ull;
+  }
+}
+
+namespace hn = hwy::HWY_NAMESPACE;
+using D = hn::Full128<uint8_t>;  // 128 bits for AES
+using V = hn::Vec<D>;
+
+static V Load(const uint64_t* ptr) {
+  return hn::Load(D(), reinterpret_cast<const uint8_t*>(ptr));
+}
+
+RNG::result_type RNG::operator()() {
+  V state = Load(counter_);
+  counter_[0]++;
+  state = hn::Xor(state, Load(key_));  // initial whitening
+
+  static_assert(kRounds == 5 && sizeof(key_) == 12 * sizeof(uint64_t));
+  state = hn::AESRound(state, Load(key_ + 2));
+  state = hn::AESRound(state, Load(key_ + 4));
+  state = hn::AESRound(state, Load(key_ + 6));
+  state = hn::AESRound(state, Load(key_ + 8));
+  // Final round: fine to use another AESRound, including MixColumns.
+  state = hn::AESRound(state, Load(key_ + 10));
+
+  // Return lower 64 bits of the u8 vector.
+  const hn::Repartition<uint64_t, D> d64;
+  return hn::GetLane(hn::BitCast(d64, state));
+}
+
+}  // namespace gcpp
diff --git a/util/basics.h b/util/basics.h
index c8858e5..2429c72 100644
--- a/util/basics.h
+++ b/util/basics.h
@@ -119,6 +119,42 @@ static inline IndexRange MakeIndexRange(size_t begin, size_t end,
                                         size_t max_size) {
   return IndexRange(begin, HWY_MIN(begin + max_size, end));
 }
+
+// Non-cryptographic 64-bit pseudo-random number generator. Supports random or
+// deterministic seeding. Conforms to C++ `UniformRandomBitGenerator`.
+//
+// Based on 5-round AES-CTR. Supports 2^64 streams, each with period 2^64. This
+// is useful for parallel sampling. Each thread can generate the stream for a
+// particular task, without caring about prior/subsequent generations.
+class alignas(16) RNG {
+  // "Large-scale randomness study of security margins for 100+ cryptographic
+  // functions": at least four.
+  // "Parallel Random Numbers: As Easy as 1, 2, 3": four not Crush-resistant.
+  static constexpr size_t kRounds = 5;
+
+ public:
+  explicit RNG(bool deterministic);
+
+  void SetStream(uint64_t stream) {
+    counter_[1] = stream;
+    counter_[0] = 0;
+  }
+
+  using result_type = uint64_t;
+  static constexpr result_type min() { return 0; }
+  static constexpr result_type max() { return ~result_type{0}; }
+
+  // About 100M/s on 3 GHz Skylake. Throughput could be increased 4x via
+  // unrolling by the AES latency (4-7 cycles). `std::discrete_distribution`
+  // makes individual calls to the generator, which would require buffering,
+  // which is not worth the complexity.
+  result_type operator()();
+
+ private:
+  uint64_t counter_[2] = {};
+  uint64_t key_[2 * (1 + kRounds)];
+};
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_
diff --git a/util/basics_test.cc b/util/basics_test.cc
new file mode 100644
index 0000000..169d051
--- /dev/null
+++ b/util/basics_test.cc
@@ -0,0 +1,108 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "util/basics.h"
+
+#include <stddef.h>
+#include <stdio.h>
+
+#include "hwy/base.h"
+#include "hwy/tests/hwy_gtest.h"
+#include "hwy/timer.h"
+
+namespace gcpp {
+namespace {
+
+TEST(BasicsTest, IsDeterministic) {
+  RNG rng1(/*deterministic=*/true);
+  RNG rng2(/*deterministic=*/true);
+  // Remember for later testing after resetting the stream.
+  const uint64_t r0 = rng1();
+  const uint64_t r1 = rng1();
+  // Not consecutive values. This could actually happen due to the extra XOR,
+  // but given the deterministic seeding here, we know it will not.
+  HWY_ASSERT(r0 != r1);
+  // Let rng2 catch up.
+  HWY_ASSERT(r0 == rng2());
+  HWY_ASSERT(r1 == rng2());
+
+  for (size_t i = 0; i < 1000; ++i) {
+    HWY_ASSERT(rng1() == rng2());
+  }
+
+  // Reset counter, ensure it matches the default-constructed RNG.
+  rng1.SetStream(0);
+  HWY_ASSERT(r0 == rng1());
+  HWY_ASSERT(r1 == rng1());
+}
+
+TEST(BasicsTest, IsSeeded) {
+  RNG rng1(/*deterministic=*/true);
+  RNG rng2(/*deterministic=*/false);
+  // It would be very unlucky to have even one 64-bit value match, and two are
+  // extremely unlikely.
+  const uint64_t a0 = rng1();
+  const uint64_t a1 = rng1();
+  const uint64_t b0 = rng2();
+  const uint64_t b1 = rng2();
+  HWY_ASSERT(a0 != b0 || a1 != b1);
+}
+
+// If not close to 50% 1-bits, the RNG is quite broken.
+TEST(BasicsTest, BitDistribution) {
+  RNG rng(/*deterministic=*/true);
+  constexpr size_t kU64 = 2 * 1000 * 1000;
+  const hwy::Timestamp t0;
+  uint64_t one_bits = 0;
+  for (size_t i = 0; i < kU64; ++i) {
+    one_bits += hwy::PopCount(rng());
+  }
+  const uint64_t total_bits = kU64 * 64;
+  const double one_ratio = static_cast<double>(one_bits) / total_bits;
+  const double elapsed = hwy::SecondsSince(t0);
+  fprintf(stderr, "1-bit ratio %.5f, %.1f M/s\n", one_ratio,
+          kU64 / elapsed * 1E-6);
+  HWY_ASSERT(0.4999 <= one_ratio && one_ratio <= 0.5001);
+}
+
+TEST(BasicsTest, ChiSquared) {
+  RNG rng(/*deterministic=*/true);
+  constexpr size_t kU64 = 1 * 1000 * 1000;
+
+  // Test each byte separately.
+  for (size_t shift = 0; shift < 64; shift += 8) {
+    size_t counts[256] = {};
+    for (size_t i = 0; i < kU64; ++i) {
+      const size_t byte = (rng() >> shift) & 0xFF;
+      counts[byte]++;
+    }
+
+    double chi_squared = 0.0;
+    const double expected = static_cast<double>(kU64) / 256.0;
+    for (size_t i = 0; i < 256; ++i) {
+      const double diff = static_cast<double>(counts[i]) - expected;
+      chi_squared += diff * diff / expected;
+    }
+    // Should be within ~0.5% and 99.5% percentiles. See
+    // https://www.medcalc.org/manual/chi-square-table.php
+    if (chi_squared < 196.0 || chi_squared > 311.0) {
+      HWY_ABORT("Chi-squared byte %zu: %.5f \n", shift / 8, chi_squared);
+    }
+  }
+}
+
+}  // namespace
+}  // namespace gcpp
+HWY_TEST_MAIN();