Minor cleanup/fixes:

- optimize_test simplify prompt check
- Fix SFP arg case
- Fix includes
- Align inputs in test
- IsInside: add DASSERT
- Fix PerClusterPool NumThreads

PiperOrigin-RevId: 672530385
Jan Wassenberg 2024-09-09 06:57:29 -07:00 committed by Copybara-Service
parent c29e9752c7
commit 5c0da8c8c3
7 changed files with 56 additions and 25 deletions

@@ -86,13 +86,11 @@ TEST(OptimizeTest, GradientDescent) {
   // 1) Its length should be greater than the prompt.
   // 2) The prompt should be a prefix of the reply.
   auto verify = [&](const Prompt& prompt) {
-    auto context = prompt.context();
+    const std::vector<int>& context = prompt.context();
     std::vector<int> reply = generate(context);
-    bool ok = true;
-    ok &= (reply.size() > context.size());
-    ok &= std::equal(prompt.tokens.begin(), prompt.tokens.end(),
-                     reply.begin(), reply.begin() + prompt.tokens.size());
-    return ok;
+    if (reply.size() <= context.size()) return false;
+    return std::equal(context.begin(), context.end(), reply.begin(),
+                      reply.begin() + context.size());
   };
   RandInitWeights(info.model, info.weight, gemma.Weights(), pool, gen);
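For reference, the simplified check relies on the four-iterator overload of std::equal, comparing the whole context against the same-length prefix of the reply. A minimal standalone sketch of the same logic (function and variable names here are illustrative, not part of the test):

#include <algorithm>
#include <cstdio>
#include <vector>

// Returns true if `context` is a strict prefix of `reply`, i.e. the reply is
// longer and starts with the context tokens.
static bool IsStrictPrefix(const std::vector<int>& context,
                           const std::vector<int>& reply) {
  if (reply.size() <= context.size()) return false;
  // Four-iterator std::equal compares [context.begin(), context.end())
  // against the equally long prefix of `reply`.
  return std::equal(context.begin(), context.end(), reply.begin(),
                    reply.begin() + context.size());
}

int main() {
  const std::vector<int> context = {1, 2, 3};
  printf("%d\n", IsStrictPrefix(context, {1, 2, 3, 4}));  // 1
  printf("%d\n", IsStrictPrefix(context, {1, 2, 3}));     // 0: not longer
  printf("%d\n", IsStrictPrefix(context, {1, 9, 3, 4}));  // 0: mismatch
  return 0;
}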

@@ -43,8 +43,8 @@ namespace gcpp {
 static inline const char* TypeName(float) { return "f32"; }
 static inline const char* TypeName(BF16) { return "b16"; }
-static inline const char* TypeName(SfpStream) { return "SFP"; }
-static inline const char* TypeName(NuqStream) { return "NUQ"; }
+static inline const char* TypeName(SfpStream) { return "sfp"; }
+static inline const char* TypeName(NuqStream) { return "nuq"; }
 // Returns the number of `MatT` elements required to store `capacity` values,
 // which must not be zero.

@@ -1,12 +1,17 @@
 #include "compression/python/compression_clif_aux.h"
+#include <string>
+#include <vector>
+#include "compression/compress.h"
 #undef HWY_TARGET_INCLUDE
 #define HWY_TARGET_INCLUDE \
   "compression/python/compression_clif_aux.cc"  // NOLINT
 #include "hwy/foreach_target.h"  // IWYU pragma: keep
-// Must come after foreach_target.h to avoid redefinition errors.
-#include "compression/compress-inl.h"
 #include "hwy/highway.h"
+// After highway.h
+#include "compression/compress-inl.h"
 // Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
 // compile pass, whereas we want this defined in the first.
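The include reorder follows Highway's multi-target compilation idiom: the file names itself via HWY_TARGET_INCLUDE so that foreach_target.h can re-include it once per SIMD target, and any *-inl.h headers must come after hwy/highway.h so the per-target macros they depend on are already defined. A hedged skeleton of that pattern (file, namespace, and function names are placeholders, not taken from compression_clif_aux.cc):

// example_simd.cc -- placeholder name for a translation unit using the idiom.
#include <cstddef>

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "example_simd.cc"  // this file re-includes itself
#include "hwy/foreach_target.h"               // IWYU pragma: keep
#include "hwy/highway.h"
// Any *-inl.h headers go here, after highway.h, because they rely on the
// per-target macros (HWY_TARGET, HWY_NAMESPACE) being defined.

HWY_BEFORE_NAMESPACE();
namespace example {
namespace HWY_NAMESPACE {  // re-opened once per compiled target

// Doubles each element; assumes `n` is a multiple of the vector length.
void MulBy2(float* HWY_RESTRICT x, size_t n) {
  namespace hn = hwy::HWY_NAMESPACE;
  const hn::ScalableTag<float> d;
  for (size_t i = 0; i < n; i += hn::Lanes(d)) {
    hn::StoreU(hn::Mul(hn::LoadU(d, x + i), hn::Set(d, 2.0f)), d, x + i);
  }
}

}  // namespace HWY_NAMESPACE
}  // namespace example
HWY_AFTER_NAMESPACE();

#if HWY_ONCE  // true only on the last compile pass; non-SIMD code goes here
namespace example {
HWY_EXPORT(MulBy2);
void CallMulBy2(float* x, size_t n) { HWY_DYNAMIC_DISPATCH(MulBy2)(x, n); }
}  // namespace example
#endif  // HWY_ONCE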

@@ -463,10 +463,10 @@ HWY_NOINLINE void ScalarRMSNorm(const VecT* x,
 template <typename VecT, typename WeightT, typename OutT>
 void TestRMSNorm(hwy::RandomState& rng) {
   constexpr size_t kSize = 128;
-  VecT vec[kSize];
-  WeightT weight[kSize];
-  OutT expected[kSize];
-  OutT actual[kSize];
+  HWY_ALIGN VecT vec[kSize];
+  HWY_ALIGN WeightT weight[kSize];
+  HWY_ALIGN OutT expected[kSize];
+  HWY_ALIGN OutT actual[kSize];
   for (size_t i = 0; i < kSize; ++i) {
     vec[i] = hwy::ConvertScalarTo<VecT>(RandomGaussian(rng));
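HWY_ALIGN aligns the stack arrays to Highway's vector alignment so the ops under test may use aligned loads and stores on them. A minimal hedged illustration of why that matters (not part of the test itself):

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

void AlignedExample() {
  // A plain `float buf[64]` is only guaranteed alignof(float) alignment;
  // hn::Load/hn::Store require the pointer to be vector-aligned, which
  // HWY_ALIGN provides.
  HWY_ALIGN float buf[64] = {};
  const hn::CappedTag<float, 4> d;  // at most 4 lanes, for a simple demo
  const auto v = hn::Load(d, buf);  // aligned load is safe here
  hn::Store(v, d, buf);             // aligned store likewise
}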

@@ -152,7 +152,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
             "gr2b-pt = griffin 2B parameters, pretrained\n "
             " Required argument.");
     visitor(weight_type_str, "weight_type", std::string("sfp"),
-            "Weight type\n f32 = float, bf16 = bfloat16, SFP = 8-bit FP\n"
+            "Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit FP\n"
             " Required argument.");
   }

@@ -51,6 +51,7 @@ HWY_INLINE double RandomGaussian(hwy::RandomState& rng) {
 // Returns true if val is inside [min, max].
 template <typename T>
 static inline bool IsInside(T expected_min, T expected_max, T val) {
+  HWY_DASSERT(expected_min <= expected_max);
   return expected_min <= val && val <= expected_max;
 }

@@ -106,12 +106,36 @@ class PerClusterPools {
     }
   }
-  // The defaults for `AppArgs` `max_clusters` and `num_threads` are zero, which
-  // means no limit.
-  size_t CapIfNonzero(size_t detected, size_t user_max_or_zero) {
+  // `user_max_or_zero` == 0 means no limit, which is the case for the defaults
+  // of `AppArgs` `max_clusters` and `num_threads`.
+  static inline size_t CapIfNonZero(size_t num_workers,
+                                    size_t user_max_or_zero) {
+    return (user_max_or_zero == 0) ? num_workers
+                                   : HWY_MIN(num_workers, user_max_or_zero);
+  }
+  // Returns the number of threads for `ThreadPool` to create: zero if there is
+  // no threading support, otherwise the capped number of workers minus the
+  // caller of `ThreadPool::Run`, which is the outer worker or main thread.
+  size_t CappedNumThreads(size_t num_workers, size_t user_max_or_zero) const {
     if (!have_threading_support_) return 0;
-    return (user_max_or_zero == 0) ? detected
-                                   : HWY_MIN(detected, user_max_or_zero);
+    const size_t capped_num_workers =
+        CapIfNonZero(num_workers, user_max_or_zero);
+    // Avoid underflow if number of workers is zero.
+    return capped_num_workers == 0 ? 0 : capped_num_workers - 1;
+  }
+  // Returns the number of workers for the inner pool whose index is `outer`, or
+  // 0 to indicate no limit if `max_threads` is zero.
+  size_t MaxInnerWorkers(const size_t max_threads, const size_t outer_workers,
+                         const size_t outer) const {
+    HWY_DASSERT(outer < outer_workers);
+    if (max_threads == 0) return 0;  // no limit
+    // Round down so we do not exceed the max.
+    const size_t max_threads_per_outer = max_threads / outer_workers;
+    // First outer pool gets the remainder.
+    const size_t remainder = (outer == 0) ? (max_threads % outer_workers) : 0;
+    return 1 + max_threads_per_outer + remainder;
   }

  public:
@@ -120,19 +144,21 @@ class PerClusterPools {
   // result in threads not running on their own core, we only allow for
   // *upper bounds* on the number of clusters and threads. The actual number of
   // clusters and threads are still limited by the detected topology.
+  // `max_threads` is the upper bound on threads to distribute among clusters,
+  // not including the one outer thread per cluster.
   //
   // `pin` is 0 or 1 to force enable/disable, or -1 to choose automatically.
   PerClusterPools(size_t max_clusters, size_t max_threads, int pin = -1)
       : have_threading_support_(hwy::HaveThreadingSupport()),
         cores_per_cluster_(DetectCoresPerCluster()),
-        outer_pool_(CapIfNonzero(cores_per_cluster_.size(), max_clusters)) {
+        outer_pool_(CappedNumThreads(cores_per_cluster_.size(), max_clusters)) {
     // Topology detection failed - it currently requires Linux.
     if (cores_per_cluster_.empty()) {
       // Create a single inner pool with up to TotalLogicalProcessors() / 2
       // workers, further limited by `max_threads` if nonzero, and then pin to
       // the first N processors, which are typically on the first socket.
       const size_t num_threads =
-          CapIfNonzero(hwy::TotalLogicalProcessors() / 2, max_threads);
+          CappedNumThreads(hwy::TotalLogicalProcessors() / 2, max_threads);
       if (pin == -1) pin = num_threads > 8;
       fprintf(stderr, "CPU topology unknown, using %zu threads, pin %d\n",
               num_threads, pin);
@@ -146,10 +172,11 @@ class PerClusterPools {
       return;
     }
-    const size_t max_per_inner = max_threads / outer_pool_.NumWorkers();
     for (size_t outer = 0; outer < outer_pool_.NumWorkers(); ++outer) {
-      const size_t num_threads =
-          CapIfNonzero(cores_per_cluster_[outer].Count(), max_per_inner);
+      const size_t max_inner_workers =
+          MaxInnerWorkers(max_threads, outer_pool_.NumWorkers(), outer);
+      const size_t num_threads = CappedNumThreads(
+          cores_per_cluster_[outer].Count(), max_inner_workers);
       inner_pools_.push_back(std::make_unique<hwy::ThreadPool>(num_threads));
     }
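A hedged worked example of how the new helpers cooperate (topology numbers are made up): with max_threads = 10 and 4 detected clusters, MaxInnerWorkers grants the first inner pool 1 + 10/4 + 10%4 = 5 workers and the others 3 each; CappedNumThreads then caps by the cluster's core count and subtracts the calling thread, so at most 4 + 2 + 2 + 2 = 10 new threads are created, matching the limit. The sketch below re-implements the arithmetic outside the class purely for illustration (the real helpers also handle missing threading support):

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Standalone copies of the helpers, for illustration only; the real ones are
// members of PerClusterPools and also return 0 when threading is unavailable.
static size_t CapIfNonZero(size_t num_workers, size_t user_max_or_zero) {
  return (user_max_or_zero == 0) ? num_workers
                                 : std::min(num_workers, user_max_or_zero);
}

static size_t CappedNumThreads(size_t num_workers, size_t user_max_or_zero) {
  const size_t capped = CapIfNonZero(num_workers, user_max_or_zero);
  return capped == 0 ? 0 : capped - 1;  // minus the calling thread
}

static size_t MaxInnerWorkers(size_t max_threads, size_t outer_workers,
                              size_t outer) {
  if (max_threads == 0) return 0;  // no limit
  const size_t per_outer = max_threads / outer_workers;  // round down
  const size_t remainder = (outer == 0) ? (max_threads % outer_workers) : 0;
  return 1 + per_outer + remainder;  // +1: the outer caller is also a worker
}

int main() {
  // Hypothetical topology: 4 clusters of 8 cores, user limit of 10 threads.
  const size_t outer_workers = 4, cores_per_cluster = 8, max_threads = 10;
  size_t total = 0;
  for (size_t outer = 0; outer < outer_workers; ++outer) {
    const size_t max_inner = MaxInnerWorkers(max_threads, outer_workers, outer);
    const size_t num_threads = CappedNumThreads(cores_per_cluster, max_inner);
    printf("cluster %zu: up to %zu workers -> %zu new threads\n", outer,
           max_inner, num_threads);
    total += num_threads;
  }
  printf("total new threads: %zu (limit %zu)\n", total, max_threads);
  return 0;
}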