Refactor GemmaImpl dispatch to use Highway 1.2's HWY_DYNAMIC_DISPATCH_T

PiperOrigin-RevId: 639793810
Authored by Paul Chang on 2024-06-03 08:31:47 -07:00; committed by Copybara-Service
parent a44cbdadc2
commit ed8f39c058
1 changed file with 28 additions and 149 deletions
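For context, the Highway 1.2 macros this change relies on (HWY_EXPORT_T, HWY_DYNAMIC_DISPATCH_T, and the combined HWY_EXPORT_AND_DYNAMIC_DISPATCH_T) export and dispatch a templated per-target function per instantiation, which removes the need for one hand-written wrapper per model config. Below is a minimal, hypothetical sketch of that pattern, not taken from gemma.cpp: the file name dispatch_example.cc, the example namespace, and KernelImpl/Config::kScale are illustrative assumptions.

// dispatch_example.cc -- hypothetical, for illustration only.
#include <cstdio>

// Re-include this file once per enabled SIMD target.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "dispatch_example.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace example {
namespace HWY_NAMESPACE {

// Templated per-target kernel: one instantiation per (SIMD target, Config).
template <typename Config>
float KernelImpl(float x) {
  return x * Config::kScale;
}

}  // namespace HWY_NAMESPACE
}  // namespace example
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace example {

struct ConfigA { static constexpr float kScale = 2.0f; };

template <typename Config>
float Kernel(float x) {
  // HWY_EXPORT_T declares a dispatch table (named KernelTable here) for this
  // particular instantiation; HWY_DYNAMIC_DISPATCH_T picks the best target's
  // entry at runtime. HWY_EXPORT_AND_DYNAMIC_DISPATCH_T combines both steps.
  HWY_EXPORT_T(KernelTable, KernelImpl<Config>);
  return HWY_DYNAMIC_DISPATCH_T(KernelTable)(x);
}

}  // namespace example

int main() {
  std::printf("%g\n", example::Kernel<example::ConfigA>(3.0f));
  return 0;
}
#endif  // HWY_ONCE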


@@ -1257,56 +1257,6 @@ float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
#undef TOKEN
-void Generate2B(GemmaImpl<ConfigGemma2B>& gemma,
-                const RuntimeConfig& runtime_config,
-                const std::vector<int>& prompt, size_t start_pos,
-                KVCache& kv_cache, hwy::ThreadPool& pool,
-                TimingInfo& timing_info, LayersOutputT* layers_output) {
-  GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
-               timing_info, layers_output);
-}
-void Generate7B(GemmaImpl<ConfigGemma7B>& gemma,
-                const RuntimeConfig& runtime_config,
-                const std::vector<int>& prompt, size_t start_pos,
-                KVCache& kv_cache, hwy::ThreadPool& pool,
-                TimingInfo& timing_info, LayersOutputT* layers_output) {
-  GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
-               timing_info, layers_output);
-}
-void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
-                       const RuntimeConfig& runtime_config,
-                       const std::vector<int>& prompt, size_t start_pos,
-                       KVCache& kv_cache, hwy::ThreadPool& pool,
-                       TimingInfo& timing_info, LayersOutputT* layers_output) {
-  GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
-               timing_info, layers_output);
-}
-float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
-                            const std::vector<int>& prompt, KVCache& kv_cache,
-                            hwy::ThreadPool& pool, int verbosity) {
-  return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
-                                 verbosity);
-}
-float ComputeCrossEntropy7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
-                            const std::vector<int>& prompt, KVCache& kv_cache,
-                            hwy::ThreadPool& pool, int verbosity) {
-  return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
-                                 verbosity);
-}
-float ComputeCrossEntropyGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
-                                   size_t max_tokens,
-                                   const std::vector<int>& prompt,
-                                   KVCache& kv_cache, hwy::ThreadPool& pool,
-                                   int verbosity) {
-  return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
-                                 verbosity);
-}
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
// if weights = null, which happens during the first call where we attempt to
// load from cache.
@@ -1416,36 +1366,6 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeights(
  return c_weights_u8;
}
-// Type-erased because this function is called via a function pointer.
-hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeightsT(
-    gcpp::Model model, const Path& weights, hwy::ThreadPool& pool) {
-  switch (model) {
-    case Model::GEMMA_2B:
-      return LoadCompressedWeights<ConfigGemma2B>(weights, pool);
-    case Model::GEMMA_7B:
-      return LoadCompressedWeights<ConfigGemma7B>(weights, pool);
-    case Model::GRIFFIN_2B:
-      return LoadCompressedWeights<ConfigGriffin2B>(weights, pool);
-    default:
-      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
-  }
-}
-hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeightsT(gcpp::Model model,
-                                                  const Path& weights,
-                                                  hwy::ThreadPool& pool) {
-  switch (model) {
-    case Model::GEMMA_2B:
-      return LoadWeights<ConfigGemma2B>(weights, pool);
-    case Model::GEMMA_7B:
-      return LoadWeights<ConfigGemma7B>(weights, pool);
-    case Model::GRIFFIN_2B:
-      return LoadWeights<ConfigGriffin2B>(weights, pool);
-    default:
-      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
-  }
-}
template <class TConfig>
void CompressWeights(const Path& weights_path,
                     const Path& compressed_weights_path,
@@ -1501,15 +1421,7 @@ HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
-HWY_EXPORT(LoadCompressedWeightsT);
-HWY_EXPORT(LoadWeightsT);
HWY_EXPORT(CompressWeightsT);
-HWY_EXPORT(Generate2B);
-HWY_EXPORT(Generate7B);
-HWY_EXPORT(GenerateGriffin2B);
-HWY_EXPORT(ComputeCrossEntropy2B);
-HWY_EXPORT(ComputeCrossEntropy7B);
-HWY_EXPORT(ComputeCrossEntropyGriffin2B);
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
                      size_t conv1d_cache_size, size_t rglru_cache_size) {
@@ -1540,68 +1452,31 @@ GemmaImpl<Config>::GemmaImpl(
      prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
      state(hwy::MakeUniqueAligned<Activations<Config, 1>>()) {}
-template <>
-void GemmaImpl<ConfigGemma2B>::Generate(const RuntimeConfig& runtime_config,
-                                        const std::vector<int>& prompt,
-                                        size_t start_pos, KVCache& kv_cache,
-                                        hwy::ThreadPool& pool,
-                                        TimingInfo& timing_info,
-                                        LayersOutputT* layers_output) {
-  HWY_DYNAMIC_DISPATCH(Generate2B)
-  (*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
-   layers_output);
-}
-template <>
-void GemmaImpl<ConfigGemma7B>::Generate(const RuntimeConfig& runtime_config,
-                                        const std::vector<int>& prompt,
-                                        size_t start_pos, KVCache& kv_cache,
-                                        hwy::ThreadPool& pool,
-                                        TimingInfo& timing_info,
-                                        LayersOutputT* layers_output) {
-  HWY_DYNAMIC_DISPATCH(Generate7B)
-  (*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
-   layers_output);
-}
-template <>
-void GemmaImpl<ConfigGriffin2B>::Generate(const RuntimeConfig& runtime_config,
-                                          const std::vector<int>& prompt,
-                                          size_t start_pos, KVCache& kv_cache,
-                                          hwy::ThreadPool& pool,
-                                          TimingInfo& timing_info,
-                                          LayersOutputT* layers_output) {
-  HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
-  (*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
-   layers_output);
-}
-template <>
-float GemmaImpl<ConfigGemma2B>::ComputeCrossEntropy(
-    size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
-    hwy::ThreadPool& pool, int verbosity) {
-  return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy2B)(
-      *this, max_tokens, prompt, kv_cache, pool, verbosity);
-}
-template <>
-float GemmaImpl<ConfigGemma7B>::ComputeCrossEntropy(
-    size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
-    hwy::ThreadPool& pool, int verbosity) {
-  return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy7B)(
-      *this, max_tokens, prompt, kv_cache, pool, verbosity);
-}
-template <>
-float GemmaImpl<ConfigGriffin2B>::ComputeCrossEntropy(
-    size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
-    hwy::ThreadPool& pool, int verbosity) {
-  return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropyGriffin2B)(
-      *this, max_tokens, prompt, kv_cache, pool, verbosity);
-}
-Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
-             hwy::ThreadPool& pool) {
+template <typename Config>
+void GemmaImpl<Config>::Generate(const RuntimeConfig& runtime_config,
+                                 const std::vector<int>& prompt,
+                                 size_t start_pos, KVCache& kv_cache,
+                                 hwy::ThreadPool& pool, TimingInfo& timing_info,
+                                 LayersOutputT* layers_output) {
+  HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateImpl<Config>)
+  (*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
+   layers_output);
+}
+template <typename Config>
+float GemmaImpl<Config>::ComputeCrossEntropy(size_t max_tokens,
+                                             const std::vector<int>& prompt,
+                                             KVCache& kv_cache,
+                                             hwy::ThreadPool& pool,
+                                             int verbosity) {
+  HWY_EXPORT_T(ComputeCrossEntropyT, ComputeCrossEntropyImpl<Config>);
+  return HWY_DYNAMIC_DISPATCH_T(ComputeCrossEntropyT)(
+      *this, max_tokens, prompt, kv_cache, pool, verbosity);
+}
+template <class Config>
+GemmaImpl<Config>* CreateGemmaImpl(const Path& tokenizer_path,
+                                   const Path& weights, hwy::ThreadPool& pool) {
  std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
  {
    PROFILER_ZONE("Startup.tokenizer");
@@ -1613,21 +1488,25 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
  hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8;
  if constexpr (kWeightsAreCompressed) {
-    weights_u8 =
-        HWY_DYNAMIC_DISPATCH(LoadCompressedWeightsT)(model_type, weights, pool);
+    HWY_EXPORT_T(LoadCompressedWeightsT, LoadCompressedWeights<Config>);
+    weights_u8 = HWY_DYNAMIC_DISPATCH_T(LoadCompressedWeightsT)(weights, pool);
  } else {
-    weights_u8 = HWY_DYNAMIC_DISPATCH(LoadWeightsT)(model_type, weights, pool);
+    weights_u8 = LoadWeights<Config>(weights, pool);
  }
+  return new GemmaImpl<Config>(tokenizer, weights_u8, pool);
+}
+Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
+             hwy::ThreadPool& pool) {
  switch (model_type) {
    case Model::GEMMA_2B:
-      impl_.reset(new GemmaImpl<ConfigGemma2B>(tokenizer, weights_u8, pool));
+      impl_.reset(CreateGemmaImpl<ConfigGemma2B>(tokenizer_path, weights, pool));
      break;
    case Model::GEMMA_7B:
-      impl_.reset(new GemmaImpl<ConfigGemma7B>(tokenizer, weights_u8, pool));
+      impl_.reset(CreateGemmaImpl<ConfigGemma7B>(tokenizer_path, weights, pool));
      break;
    case Model::GRIFFIN_2B:
-      impl_.reset(new GemmaImpl<ConfigGriffin2B>(tokenizer, weights_u8, pool));
+      impl_.reset(CreateGemmaImpl<ConfigGriffin2B>(tokenizer_path, weights, pool));
      break;
    default:
      HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));