From 3b4fa4a0e3678eb6751c6920bf053c81429d36e6 Mon Sep 17 00:00:00 2001 From: Zoltan Szabadka Date: Tue, 4 Jun 2024 09:18:56 +0000 Subject: [PATCH] Use HWY_EXPORT_AND_DYNAMIC_DISPATCH_T where possible. --- gemma/gemma.cc | 112 +++++++++++++++++++++++++------------------------ 1 file changed, 58 insertions(+), 54 deletions(-) diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 888e447..d168a9d 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -1106,12 +1106,12 @@ void GenerateImpl(GemmaImpl& gemma, } template -void GenerateImpl(const ByteStorageT& weights_u8, - ByteStorageT& inference_state_u8, - const RuntimeConfig& runtime_config, - const std::vector& prompt, size_t pos, - KVCache& kv_cache, hwy::ThreadPool& pool, - TimingInfo& timing_info, LayersOutputT* layers_output) { +void GenerateGemma(const ByteStorageT& weights_u8, + ByteStorageT& inference_state_u8, + const RuntimeConfig& runtime_config, + const std::vector& prompt, size_t pos, + KVCache& kv_cache, hwy::ThreadPool& pool, + TimingInfo& timing_info, LayersOutputT* layers_output) { const WeightsF& weights = *reinterpret_cast*>(weights_u8.get()); InferenceState& inference_state = @@ -1121,28 +1121,6 @@ void GenerateImpl(const ByteStorageT& weights_u8, prompt, pos, kv_cache, pool, timing_info, layers_output); } -void GenerateImplT(Model model, const ByteStorageT& weights_u8, - ByteStorageT& inference_state_u8, - const RuntimeConfig& runtime_config, - const std::vector& prompt, size_t pos, - KVCache& kv_cache, hwy::ThreadPool& pool, - TimingInfo& timing_info, LayersOutputT* layers_output) { - switch (model) { - case Model::GEMMA_2B: - GenerateImpl( - weights_u8, inference_state_u8, runtime_config, prompt, pos, kv_cache, - pool, timing_info, layers_output); - break; - case Model::GEMMA_TINY: - GenerateImpl( - weights_u8, inference_state_u8, runtime_config, prompt, pos, kv_cache, - pool, timing_info, layers_output); - break; - default: - HWY_ABORT("Model type %d unknown.", static_cast(model)); - } -} - #define TOKEN(token_id) TokenString(gemma, token_id).c_str() template @@ -1348,23 +1326,6 @@ void CompressWeights(const Path& weights_path, c_weights->c_layer_ptrs.~CompressedLayerPointers(); } -void CompressWeightsT(gcpp::Model model, const Path& weights, - const Path& compressed_weights, hwy::ThreadPool& pool) { - switch (model) { - case Model::GEMMA_2B: - CompressWeights(weights, compressed_weights, pool); - break; - case Model::GEMMA_7B: - CompressWeights(weights, compressed_weights, pool); - break; - case Model::GRIFFIN_2B: - CompressWeights(weights, compressed_weights, pool); - break; - default: - HWY_ABORT("Model type %d unknown.", static_cast(model)); - } -} - } // namespace HWY_NAMESPACE } // namespace gcpp HWY_AFTER_NAMESPACE(); @@ -1372,9 +1333,6 @@ HWY_AFTER_NAMESPACE(); #if HWY_ONCE namespace gcpp { -HWY_EXPORT(CompressWeightsT); -HWY_EXPORT(GenerateImplT); - KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len, size_t conv1d_cache_size, size_t rglru_cache_size) { KVCache kv_cache = {}; @@ -1486,14 +1444,46 @@ void GenerateGemma(Model model, const ByteStorageT& weights, const std::vector& prompt, size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, TimingInfo& timing_info) { - HWY_DYNAMIC_DISPATCH(GenerateImplT)( - model, weights, inference_state, runtime_config, prompt, start_pos, - kv_cache, pool, timing_info, /*layers_output=*/nullptr); + switch (model) { + case Model::GEMMA_2B: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma)( + weights, inference_state, runtime_config, prompt, start_pos, kv_cache, + pool, timing_info, /*layers_output=*/nullptr); + break; + case Model::GEMMA_7B: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma)( + weights, inference_state, runtime_config, prompt, start_pos, kv_cache, + pool, timing_info, /*layers_output=*/nullptr); + break; + case Model::GRIFFIN_2B: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma)( + weights, inference_state, runtime_config, prompt, start_pos, kv_cache, + pool, timing_info, /*layers_output=*/nullptr); + break; + case Model::GEMMA_TINY: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma)( + weights, inference_state, runtime_config, prompt, start_pos, kv_cache, + pool, timing_info, /*layers_output=*/nullptr); + break; + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } } ByteStorageT LoadWeights(const Path& weights, Model model, hwy::ThreadPool& pool) { - return HWY_DYNAMIC_DISPATCH(LoadWeightsT)(model, weights, pool); + switch (model) { + case Model::GEMMA_2B: + return LoadWeights(weights, pool); + case Model::GEMMA_7B: + return LoadWeights(weights, pool); + case Model::GRIFFIN_2B: + return LoadWeights(weights, pool); + case Model::GEMMA_TINY: + return LoadWeights(weights, pool); + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } } ByteStorageT AllocateInferenceState(Model model) { @@ -1513,8 +1503,22 @@ ByteStorageT AllocateInferenceState(Model model) { void CompressWeights(gcpp::Model model, const Path& weights, const Path& compressed_weights, hwy::ThreadPool& pool) { - HWY_DYNAMIC_DISPATCH(CompressWeightsT) - (model, weights, compressed_weights, pool); + switch (model) { + case Model::GEMMA_2B: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights)( + weights, compressed_weights, pool); + break; + case Model::GEMMA_7B: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights)( + weights, compressed_weights, pool); + break; + case Model::GRIFFIN_2B: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights)( + weights, compressed_weights, pool); + break; + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } } float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,