diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 3cfceb3..f6f7a58 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -806,16 +806,16 @@ Activations& GetActivations( } // namespace -// Placeholder for internal test3, do not remove - -bool StreamToken(size_t query_idx, size_t pos, int token, float weight, +bool StreamToken(size_t query_idx, size_t pos, int token, float prob, const RuntimeConfig& runtime_config) { if (runtime_config.batch_stream_token) { - return runtime_config.batch_stream_token(query_idx, pos, token, weight); + return runtime_config.batch_stream_token(query_idx, pos, token, prob); } - return runtime_config.stream_token(token, weight); + return runtime_config.stream_token(token, prob); } +// Placeholder for internal test3, do not remove + template void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8, diff --git a/util/app.h b/util/app.h index 9735d26..a5a7dfd 100644 --- a/util/app.h +++ b/util/app.h @@ -133,8 +133,7 @@ class AppArgs : public ArgsBase { } static inline size_t GetSupportedThreadCount() { - return std::clamp(hwy::ThreadPool::MaxThreads(), size_t{1}, - std::min(kMaxThreads, size_t{18})); + return std::min(hwy::ThreadPool::MaxThreads(), kMaxThreads); } Path log; // output