mirror of https://github.com/google/gemma.cpp.git
Add AES-CTR RNG for parallel sampling (not yet used)
PiperOrigin-RevId: 802991142
This commit is contained in:
parent
4be4799727
commit
afd82376a5
15
BUILD.bazel
15
BUILD.bazel
|
|
@ -29,9 +29,24 @@ exports_files([
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "basics",
|
name = "basics",
|
||||||
|
srcs = ["util/basics.cc"],
|
||||||
hdrs = ["util/basics.h"],
|
hdrs = ["util/basics.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
|
"@highway//:timer",
|
||||||
|
"@highway//hwy/contrib/sort:vqsort",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "basics_test",
|
||||||
|
srcs = ["util/basics_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":basics",
|
||||||
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
|
"@highway//:hwy",
|
||||||
|
"@highway//:hwy_test_util",
|
||||||
|
"@highway//:timer",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -120,6 +120,7 @@ set(SOURCES
|
||||||
paligemma/image.h
|
paligemma/image.h
|
||||||
util/allocator.cc
|
util/allocator.cc
|
||||||
util/allocator.h
|
util/allocator.h
|
||||||
|
util/basics.cc
|
||||||
util/basics.h
|
util/basics.h
|
||||||
util/mat.cc
|
util/mat.cc
|
||||||
util/mat.h
|
util/mat.h
|
||||||
|
|
@ -227,6 +228,7 @@ set(GEMMA_TEST_FILES
|
||||||
ops/ops_test.cc
|
ops/ops_test.cc
|
||||||
paligemma/image_test.cc
|
paligemma/image_test.cc
|
||||||
paligemma/paligemma_test.cc
|
paligemma/paligemma_test.cc
|
||||||
|
util/basics_test.cc
|
||||||
util/threading_test.cc
|
util/threading_test.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,75 @@
|
||||||
|
// Copyright 2025 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "util/basics.h"
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "hwy/contrib/sort/vqsort.h"
|
||||||
|
#include "hwy/highway.h"
|
||||||
|
#include "hwy/timer.h"
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
RNG::RNG(bool deterministic) {
|
||||||
|
// Pi-based nothing up my sleeve numbers from Randen.
|
||||||
|
key_[0] = 0x243F6A8885A308D3ull;
|
||||||
|
key_[1] = 0x13198A2E03707344ull;
|
||||||
|
|
||||||
|
if (!deterministic) { // want random seed
|
||||||
|
if (!hwy::Fill16BytesSecure(key_)) {
|
||||||
|
HWY_WARN("Failed to fill RNG key with secure random bits");
|
||||||
|
// Entropy not available. The test requires that we inject some
|
||||||
|
// differences relative to the deterministic seeds.
|
||||||
|
key_[0] ^= reinterpret_cast<uint64_t>(this);
|
||||||
|
key_[1] ^= hwy::timer::Start();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simple key schedule: swap and add constant (also from Randen).
|
||||||
|
for (size_t i = 0; i < kRounds; ++i) {
|
||||||
|
key_[2 + 2 * i + 0] = key_[2 * i + 1] + 0xA4093822299F31D0ull;
|
||||||
|
key_[2 + 2 * i + 1] = key_[2 * i + 0] + 0x082EFA98EC4E6C89ull;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
using D = hn::Full128<uint8_t>; // 128 bits for AES
|
||||||
|
using V = hn::Vec<D>;
|
||||||
|
|
||||||
|
static V Load(const uint64_t* ptr) {
|
||||||
|
return hn::Load(D(), reinterpret_cast<const uint8_t*>(ptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
RNG::result_type RNG::operator()() {
|
||||||
|
V state = Load(counter_);
|
||||||
|
counter_[0]++;
|
||||||
|
state = hn::Xor(state, Load(key_)); // initial whitening
|
||||||
|
|
||||||
|
static_assert(kRounds == 5 && sizeof(key_) == 12 * sizeof(uint64_t));
|
||||||
|
state = hn::AESRound(state, Load(key_ + 2));
|
||||||
|
state = hn::AESRound(state, Load(key_ + 4));
|
||||||
|
state = hn::AESRound(state, Load(key_ + 6));
|
||||||
|
state = hn::AESRound(state, Load(key_ + 8));
|
||||||
|
// Final round: fine to use another AESRound, including MixColumns.
|
||||||
|
state = hn::AESRound(state, Load(key_ + 10));
|
||||||
|
|
||||||
|
// Return lower 64 bits of the u8 vector.
|
||||||
|
const hn::Repartition<uint64_t, D> d64;
|
||||||
|
return hn::GetLane(hn::BitCast(d64, state));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
@ -119,6 +119,42 @@ static inline IndexRange MakeIndexRange(size_t begin, size_t end,
|
||||||
size_t max_size) {
|
size_t max_size) {
|
||||||
return IndexRange(begin, HWY_MIN(begin + max_size, end));
|
return IndexRange(begin, HWY_MIN(begin + max_size, end));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Non-cryptographic 64-bit pseudo-random number generator. Supports random or
|
||||||
|
// deterministic seeding. Conforms to C++ `UniformRandomBitGenerator`.
|
||||||
|
//
|
||||||
|
// Based on 5-round AES-CTR. Supports 2^64 streams, each with period 2^64. This
|
||||||
|
// is useful for parallel sampling. Each thread can generate the stream for a
|
||||||
|
// particular task, without caring about prior/subsequent generations.
|
||||||
|
class alignas(16) RNG {
|
||||||
|
// "Large-scale randomness study of security margins for 100+ cryptographic
|
||||||
|
// functions": at least four.
|
||||||
|
// "Parallel Random Numbers: As Easy as 1, 2, 3": four not Crush-resistant.
|
||||||
|
static constexpr size_t kRounds = 5;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit RNG(bool deterministic);
|
||||||
|
|
||||||
|
void SetStream(uint64_t stream) {
|
||||||
|
counter_[1] = stream;
|
||||||
|
counter_[0] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
using result_type = uint64_t;
|
||||||
|
static constexpr result_type min() { return 0; }
|
||||||
|
static constexpr result_type max() { return ~result_type{0}; }
|
||||||
|
|
||||||
|
// About 100M/s on 3 GHz Skylake. Throughput could be increased 4x via
|
||||||
|
// unrolling by the AES latency (4-7 cycles). `std::discrete_distribution`
|
||||||
|
// makes individual calls to the generator, which would require buffering,
|
||||||
|
// which is not worth the complexity.
|
||||||
|
result_type operator()();
|
||||||
|
|
||||||
|
private:
|
||||||
|
uint64_t counter_[2] = {};
|
||||||
|
uint64_t key_[2 * (1 + kRounds)];
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,108 @@
|
||||||
|
// Copyright 2025 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "util/basics.h"
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include "hwy/base.h"
|
||||||
|
#include "hwy/tests/hwy_gtest.h"
|
||||||
|
#include "hwy/timer.h"
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
TEST(BasicsTest, IsDeterministic) {
|
||||||
|
RNG rng1(/*deterministic=*/true);
|
||||||
|
RNG rng2(/*deterministic=*/true);
|
||||||
|
// Remember for later testing after resetting the stream.
|
||||||
|
const uint64_t r0 = rng1();
|
||||||
|
const uint64_t r1 = rng1();
|
||||||
|
// Not consecutive values. This could actually happen due to the extra XOR,
|
||||||
|
// but given the deterministic seeding here, we know it will not.
|
||||||
|
HWY_ASSERT(r0 != r1);
|
||||||
|
// Let rng2 catch up.
|
||||||
|
HWY_ASSERT(r0 == rng2());
|
||||||
|
HWY_ASSERT(r1 == rng2());
|
||||||
|
|
||||||
|
for (size_t i = 0; i < 1000; ++i) {
|
||||||
|
HWY_ASSERT(rng1() == rng2());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset counter, ensure it matches the default-constructed RNG.
|
||||||
|
rng1.SetStream(0);
|
||||||
|
HWY_ASSERT(r0 == rng1());
|
||||||
|
HWY_ASSERT(r1 == rng1());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(BasicsTest, IsSeeded) {
|
||||||
|
RNG rng1(/*deterministic=*/true);
|
||||||
|
RNG rng2(/*deterministic=*/false);
|
||||||
|
// It would be very unlucky to have even one 64-bit value match, and two are
|
||||||
|
// extremely unlikely.
|
||||||
|
const uint64_t a0 = rng1();
|
||||||
|
const uint64_t a1 = rng1();
|
||||||
|
const uint64_t b0 = rng2();
|
||||||
|
const uint64_t b1 = rng2();
|
||||||
|
HWY_ASSERT(a0 != b0 || a1 != b1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If not close to 50% 1-bits, the RNG is quite broken.
|
||||||
|
TEST(BasicsTest, BitDistribution) {
|
||||||
|
RNG rng(/*deterministic=*/true);
|
||||||
|
constexpr size_t kU64 = 2 * 1000 * 1000;
|
||||||
|
const hwy::Timestamp t0;
|
||||||
|
uint64_t one_bits = 0;
|
||||||
|
for (size_t i = 0; i < kU64; ++i) {
|
||||||
|
one_bits += hwy::PopCount(rng());
|
||||||
|
}
|
||||||
|
const uint64_t total_bits = kU64 * 64;
|
||||||
|
const double one_ratio = static_cast<double>(one_bits) / total_bits;
|
||||||
|
const double elapsed = hwy::SecondsSince(t0);
|
||||||
|
fprintf(stderr, "1-bit ratio %.5f, %.1f M/s\n", one_ratio,
|
||||||
|
kU64 / elapsed * 1E-6);
|
||||||
|
HWY_ASSERT(0.4999 <= one_ratio && one_ratio <= 0.5001);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(BasicsTest, ChiSquared) {
|
||||||
|
RNG rng(/*deterministic=*/true);
|
||||||
|
constexpr size_t kU64 = 1 * 1000 * 1000;
|
||||||
|
|
||||||
|
// Test each byte separately.
|
||||||
|
for (size_t shift = 0; shift < 64; shift += 8) {
|
||||||
|
size_t counts[256] = {};
|
||||||
|
for (size_t i = 0; i < kU64; ++i) {
|
||||||
|
const size_t byte = (rng() >> shift) & 0xFF;
|
||||||
|
counts[byte]++;
|
||||||
|
}
|
||||||
|
|
||||||
|
double chi_squared = 0.0;
|
||||||
|
const double expected = static_cast<double>(kU64) / 256.0;
|
||||||
|
for (size_t i = 0; i < 256; ++i) {
|
||||||
|
const double diff = static_cast<double>(counts[i]) - expected;
|
||||||
|
chi_squared += diff * diff / expected;
|
||||||
|
}
|
||||||
|
// Should be within ~0.5% and 99.5% percentiles. See
|
||||||
|
// https://www.medcalc.org/manual/chi-square-table.php
|
||||||
|
if (chi_squared < 196.0 || chi_squared > 311.0) {
|
||||||
|
HWY_ABORT("Chi-squared byte %zu: %.5f \n", shift / 8, chi_squared);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace gcpp
|
||||||
|
HWY_TEST_MAIN();
|
||||||
Loading…
Reference in New Issue