Use benchmark_helper in py bindings (adds BOS)

Also remove the thread-count clamp (a thread count of zero or a large value is OK).

PiperOrigin-RevId: 648657155
This commit is contained in:
Jan Wassenberg 2024-07-02 03:26:29 -07:00 committed by Copybara-Service
parent e527e7662e
commit b1c1ec1d59
2 changed files with 6 additions and 7 deletions

View File

@@ -806,16 +806,16 @@ Activations<TConfig, kBatchSize>& GetActivations(
} // namespace
// Placeholder for internal test3, do not remove
bool StreamToken(size_t query_idx, size_t pos, int token, float weight,
bool StreamToken(size_t query_idx, size_t pos, int token, float prob,
const RuntimeConfig& runtime_config) {
if (runtime_config.batch_stream_token) {
return runtime_config.batch_stream_token(query_idx, pos, token, weight);
return runtime_config.batch_stream_token(query_idx, pos, token, prob);
}
return runtime_config.stream_token(token, weight);
return runtime_config.stream_token(token, prob);
}
// Placeholder for internal test3, do not remove
template <class TConfig, size_t kQueryBatchSize>
void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
const ByteStorageT& decode_u8,

View File

@@ -133,8 +133,7 @@ class AppArgs : public ArgsBase<AppArgs> {
}
static inline size_t GetSupportedThreadCount() {
return std::clamp(hwy::ThreadPool::MaxThreads(), size_t{1},
std::min(kMaxThreads, size_t{18}));
return std::min(hwy::ThreadPool::MaxThreads(), kMaxThreads);
}
Path log; // output