diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 504e1e0..52f2656 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -1257,56 +1257,6 @@ float ComputeCrossEntropyImpl(GemmaImpl& gemma, size_t max_tokens, #undef TOKEN -void Generate2B(GemmaImpl& gemma, - const RuntimeConfig& runtime_config, - const std::vector& prompt, size_t start_pos, - KVCache& kv_cache, hwy::ThreadPool& pool, - TimingInfo& timing_info, LayersOutputT* layers_output) { - GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool, - timing_info, layers_output); -} - -void Generate7B(GemmaImpl& gemma, - const RuntimeConfig& runtime_config, - const std::vector& prompt, size_t start_pos, - KVCache& kv_cache, hwy::ThreadPool& pool, - TimingInfo& timing_info, LayersOutputT* layers_output) { - GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool, - timing_info, layers_output); -} - -void GenerateGriffin2B(GemmaImpl& gemma, - const RuntimeConfig& runtime_config, - const std::vector& prompt, size_t start_pos, - KVCache& kv_cache, hwy::ThreadPool& pool, - TimingInfo& timing_info, LayersOutputT* layers_output) { - GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool, - timing_info, layers_output); -} - -float ComputeCrossEntropy2B(GemmaImpl& gemma, size_t max_tokens, - const std::vector& prompt, KVCache& kv_cache, - hwy::ThreadPool& pool, int verbosity) { - return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool, - verbosity); -} - -float ComputeCrossEntropy7B(GemmaImpl& gemma, size_t max_tokens, - const std::vector& prompt, KVCache& kv_cache, - hwy::ThreadPool& pool, int verbosity) { - return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool, - verbosity); -} - -float ComputeCrossEntropyGriffin2B(GemmaImpl& gemma, - size_t max_tokens, - const std::vector& prompt, - KVCache& kv_cache, hwy::ThreadPool& pool, - int verbosity) { - return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool, - verbosity); -} - // Calls func(name, float*, CompressedArray&) for each tensor. float* is null // if weights = null, which happens during the first call where we attempt to // load from cache. @@ -1416,36 +1366,6 @@ hwy::AlignedFreeUniquePtr LoadCompressedWeights( return c_weights_u8; } -// Type-erased because this function is called via a function pointer. -hwy::AlignedFreeUniquePtr LoadCompressedWeightsT( - gcpp::Model model, const Path& weights, hwy::ThreadPool& pool) { - switch (model) { - case Model::GEMMA_2B: - return LoadCompressedWeights(weights, pool); - case Model::GEMMA_7B: - return LoadCompressedWeights(weights, pool); - case Model::GRIFFIN_2B: - return LoadCompressedWeights(weights, pool); - default: - HWY_ABORT("Model type %d unknown.", static_cast(model)); - } -} - -hwy::AlignedFreeUniquePtr LoadWeightsT(gcpp::Model model, - const Path& weights, - hwy::ThreadPool& pool) { - switch (model) { - case Model::GEMMA_2B: - return LoadWeights(weights, pool); - case Model::GEMMA_7B: - return LoadWeights(weights, pool); - case Model::GRIFFIN_2B: - return LoadWeights(weights, pool); - default: - HWY_ABORT("Model type %d unknown.", static_cast(model)); - } -} - template void CompressWeights(const Path& weights_path, const Path& compressed_weights_path, @@ -1501,15 +1421,7 @@ HWY_AFTER_NAMESPACE(); #if HWY_ONCE namespace gcpp { -HWY_EXPORT(LoadCompressedWeightsT); -HWY_EXPORT(LoadWeightsT); HWY_EXPORT(CompressWeightsT); -HWY_EXPORT(Generate2B); -HWY_EXPORT(Generate7B); -HWY_EXPORT(GenerateGriffin2B); -HWY_EXPORT(ComputeCrossEntropy2B); -HWY_EXPORT(ComputeCrossEntropy7B); -HWY_EXPORT(ComputeCrossEntropyGriffin2B); KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len, size_t conv1d_cache_size, size_t rglru_cache_size) { @@ -1540,68 +1452,31 @@ GemmaImpl::GemmaImpl( prefill(hwy::MakeUniqueAligned>()), state(hwy::MakeUniqueAligned>()) {} -template <> -void GemmaImpl::Generate(const RuntimeConfig& runtime_config, - const std::vector& prompt, - size_t start_pos, KVCache& kv_cache, - hwy::ThreadPool& pool, - TimingInfo& timing_info, - LayersOutputT* layers_output) { - HWY_DYNAMIC_DISPATCH(Generate2B) +template +void GemmaImpl::Generate(const RuntimeConfig& runtime_config, + const std::vector& prompt, + size_t start_pos, KVCache& kv_cache, + hwy::ThreadPool& pool, TimingInfo& timing_info, + LayersOutputT* layers_output) { + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateImpl) (*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info, layers_output); } -template <> -void GemmaImpl::Generate(const RuntimeConfig& runtime_config, - const std::vector& prompt, - size_t start_pos, KVCache& kv_cache, - hwy::ThreadPool& pool, - TimingInfo& timing_info, - LayersOutputT* layers_output) { - HWY_DYNAMIC_DISPATCH(Generate7B) - (*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info, - layers_output); -} - -template <> -void GemmaImpl::Generate(const RuntimeConfig& runtime_config, - const std::vector& prompt, - size_t start_pos, KVCache& kv_cache, - hwy::ThreadPool& pool, - TimingInfo& timing_info, - LayersOutputT* layers_output) { - HWY_DYNAMIC_DISPATCH(GenerateGriffin2B) - (*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info, - layers_output); -} - -template <> -float GemmaImpl::ComputeCrossEntropy( - size_t max_tokens, const std::vector& prompt, KVCache& kv_cache, - hwy::ThreadPool& pool, int verbosity) { - return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy2B)( +template +float GemmaImpl::ComputeCrossEntropy(size_t max_tokens, + const std::vector& prompt, + KVCache& kv_cache, + hwy::ThreadPool& pool, + int verbosity) { + HWY_EXPORT_T(ComputeCrossEntropyT, ComputeCrossEntropyImpl); + return HWY_DYNAMIC_DISPATCH_T(ComputeCrossEntropyT)( *this, max_tokens, prompt, kv_cache, pool, verbosity); } -template <> -float GemmaImpl::ComputeCrossEntropy( - size_t max_tokens, const std::vector& prompt, KVCache& kv_cache, - hwy::ThreadPool& pool, int verbosity) { - return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy7B)( - *this, max_tokens, prompt, kv_cache, pool, verbosity); -} - -template <> -float GemmaImpl::ComputeCrossEntropy( - size_t max_tokens, const std::vector& prompt, KVCache& kv_cache, - hwy::ThreadPool& pool, int verbosity) { - return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropyGriffin2B)( - *this, max_tokens, prompt, kv_cache, pool, verbosity); -} - -Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type, - hwy::ThreadPool& pool) { +template +GemmaImpl* CreateGemmaImpl(const Path& tokenizer_path, + const Path& weights, hwy::ThreadPool& pool) { std::unique_ptr tokenizer; { PROFILER_ZONE("Startup.tokenizer"); @@ -1613,21 +1488,25 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type, hwy::AlignedFreeUniquePtr weights_u8; if constexpr (kWeightsAreCompressed) { - weights_u8 = - HWY_DYNAMIC_DISPATCH(LoadCompressedWeightsT)(model_type, weights, pool); + HWY_EXPORT_T(LoadCompressedWeightsT, LoadCompressedWeights); + weights_u8 = HWY_DYNAMIC_DISPATCH_T(LoadCompressedWeightsT)(weights, pool); } else { - weights_u8 = HWY_DYNAMIC_DISPATCH(LoadWeightsT)(model_type, weights, pool); + weights_u8 = LoadWeights(weights, pool); } + return new GemmaImpl(tokenizer, weights_u8, pool); +} +Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type, + hwy::ThreadPool& pool) { switch (model_type) { case Model::GEMMA_2B: - impl_.reset(new GemmaImpl(tokenizer, weights_u8, pool)); + impl_.reset(CreateGemmaImpl(tokenizer_path, weights, pool)); break; case Model::GEMMA_7B: - impl_.reset(new GemmaImpl(tokenizer, weights_u8, pool)); + impl_.reset(CreateGemmaImpl(tokenizer_path, weights, pool)); break; case Model::GRIFFIN_2B: - impl_.reset(new GemmaImpl(tokenizer, weights_u8, pool)); + impl_.reset(CreateGemmaImpl(tokenizer_path, weights, pool)); break; default: HWY_ABORT("Model type %d unknown.", static_cast(model_type));