From b1c1ec1d59ab0fe0faed590e9f685a7103ec59d0 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 2 Jul 2024 03:26:29 -0700 Subject: [PATCH] Use benchmark_helper in py bindings (adds BOS) Also remove thread clamp (OK to be zero or large). PiperOrigin-RevId: 648657155 --- gemma/gemma.cc | 10 +++++----- util/app.h | 3 +-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 3cfceb3..f6f7a58 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -806,16 +806,16 @@ Activations& GetActivations( } // namespace -// Placeholder for internal test3, do not remove - -bool StreamToken(size_t query_idx, size_t pos, int token, float weight, +bool StreamToken(size_t query_idx, size_t pos, int token, float prob, const RuntimeConfig& runtime_config) { if (runtime_config.batch_stream_token) { - return runtime_config.batch_stream_token(query_idx, pos, token, weight); + return runtime_config.batch_stream_token(query_idx, pos, token, prob); } - return runtime_config.stream_token(token, weight); + return runtime_config.stream_token(token, prob); } +// Placeholder for internal test3, do not remove + template void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8, diff --git a/util/app.h b/util/app.h index 9735d26..a5a7dfd 100644 --- a/util/app.h +++ b/util/app.h @@ -133,8 +133,7 @@ class AppArgs : public ArgsBase { } static inline size_t GetSupportedThreadCount() { - return std::clamp(hwy::ThreadPool::MaxThreads(), size_t{1}, - std::min(kMaxThreads, size_t{18})); + return std::min(hwy::ThreadPool::MaxThreads(), kMaxThreads); } Path log; // output