Use HWY_EXPORT_AND_DYNAMIC_DISPATCH_T where possible.

This commit is contained in:
Zoltan Szabadka 2024-06-04 09:18:56 +00:00
parent 8567978541
commit 3b4fa4a0e3
1 changed files with 58 additions and 54 deletions

View File

@ -1106,12 +1106,12 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma,
}
template <class TConfig>
void GenerateImpl(const ByteStorageT& weights_u8,
ByteStorageT& inference_state_u8,
const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
TimingInfo& timing_info, LayersOutputT* layers_output) {
void GenerateGemma(const ByteStorageT& weights_u8,
ByteStorageT& inference_state_u8,
const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
TimingInfo& timing_info, LayersOutputT* layers_output) {
const WeightsF<TConfig>& weights =
*reinterpret_cast<const WeightsF<TConfig>*>(weights_u8.get());
InferenceState<TConfig>& inference_state =
@ -1121,28 +1121,6 @@ void GenerateImpl(const ByteStorageT& weights_u8,
prompt, pos, kv_cache, pool, timing_info, layers_output);
}
void GenerateImplT(Model model, const ByteStorageT& weights_u8,
ByteStorageT& inference_state_u8,
const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
TimingInfo& timing_info, LayersOutputT* layers_output) {
switch (model) {
case Model::GEMMA_2B:
GenerateImpl<ConfigGemma2B>(
weights_u8, inference_state_u8, runtime_config, prompt, pos, kv_cache,
pool, timing_info, layers_output);
break;
case Model::GEMMA_TINY:
GenerateImpl<ConfigGemmaTiny>(
weights_u8, inference_state_u8, runtime_config, prompt, pos, kv_cache,
pool, timing_info, layers_output);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
#define TOKEN(token_id) TokenString(gemma, token_id).c_str()
template <class TConfig>
@ -1348,23 +1326,6 @@ void CompressWeights(const Path& weights_path,
c_weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();
}
void CompressWeightsT(gcpp::Model model, const Path& weights,
const Path& compressed_weights, hwy::ThreadPool& pool) {
switch (model) {
case Model::GEMMA_2B:
CompressWeights<ConfigGemma2B>(weights, compressed_weights, pool);
break;
case Model::GEMMA_7B:
CompressWeights<ConfigGemma7B>(weights, compressed_weights, pool);
break;
case Model::GRIFFIN_2B:
CompressWeights<ConfigGriffin2B>(weights, compressed_weights, pool);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
@ -1372,9 +1333,6 @@ HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(CompressWeightsT);
HWY_EXPORT(GenerateImplT);
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
size_t conv1d_cache_size, size_t rglru_cache_size) {
KVCache kv_cache = {};
@ -1486,14 +1444,46 @@ void GenerateGemma(Model model, const ByteStorageT& weights,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
TimingInfo& timing_info) {
HWY_DYNAMIC_DISPATCH(GenerateImplT)(
model, weights, inference_state, runtime_config, prompt, start_pos,
kv_cache, pool, timing_info, /*layers_output=*/nullptr);
switch (model) {
case Model::GEMMA_2B:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma<ConfigGemma2B>)(
weights, inference_state, runtime_config, prompt, start_pos, kv_cache,
pool, timing_info, /*layers_output=*/nullptr);
break;
case Model::GEMMA_7B:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma<ConfigGemma7B>)(
weights, inference_state, runtime_config, prompt, start_pos, kv_cache,
pool, timing_info, /*layers_output=*/nullptr);
break;
case Model::GRIFFIN_2B:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma<ConfigGriffin2B>)(
weights, inference_state, runtime_config, prompt, start_pos, kv_cache,
pool, timing_info, /*layers_output=*/nullptr);
break;
case Model::GEMMA_TINY:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma<ConfigGemmaTiny>)(
weights, inference_state, runtime_config, prompt, start_pos, kv_cache,
pool, timing_info, /*layers_output=*/nullptr);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
ByteStorageT LoadWeights(const Path& weights, Model model,
hwy::ThreadPool& pool) {
return HWY_DYNAMIC_DISPATCH(LoadWeightsT)(model, weights, pool);
switch (model) {
case Model::GEMMA_2B:
return LoadWeights<ConfigGemma2B>(weights, pool);
case Model::GEMMA_7B:
return LoadWeights<ConfigGemma7B>(weights, pool);
case Model::GRIFFIN_2B:
return LoadWeights<ConfigGriffin2B>(weights, pool);
case Model::GEMMA_TINY:
return LoadWeights<ConfigGemmaTiny>(weights, pool);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
ByteStorageT AllocateInferenceState(Model model) {
@ -1513,8 +1503,22 @@ ByteStorageT AllocateInferenceState(Model model) {
void CompressWeights(gcpp::Model model, const Path& weights,
const Path& compressed_weights, hwy::ThreadPool& pool) {
HWY_DYNAMIC_DISPATCH(CompressWeightsT)
(model, weights, compressed_weights, pool);
switch (model) {
case Model::GEMMA_2B:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<ConfigGemma2B>)(
weights, compressed_weights, pool);
break;
case Model::GEMMA_7B:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<ConfigGemma7B>)(
weights, compressed_weights, pool);
break;
case Model::GRIFFIN_2B:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<ConfigGriffin2B>)(
weights, compressed_weights, pool);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,