Refactor GemmaImpl dispatch to use Highway 1.2's HWY_DYNAMIC_DISPATCH_T

PiperOrigin-RevId: 639793810
This commit is contained in:
Paul Chang 2024-06-03 08:31:47 -07:00 committed by Copybara-Service
parent a44cbdadc2
commit ed8f39c058
1 changed files with 28 additions and 149 deletions

View File

@ -1257,56 +1257,6 @@ float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
#undef TOKEN
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma,
const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
TimingInfo& timing_info, LayersOutputT* layers_output) {
GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
timing_info, layers_output);
}
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma,
const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
TimingInfo& timing_info, LayersOutputT* layers_output) {
GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
timing_info, layers_output);
}
void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
TimingInfo& timing_info, LayersOutputT* layers_output) {
GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
timing_info, layers_output);
}
float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
hwy::ThreadPool& pool, int verbosity) {
return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
verbosity);
}
float ComputeCrossEntropy7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
hwy::ThreadPool& pool, int verbosity) {
return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
verbosity);
}
float ComputeCrossEntropyGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
size_t max_tokens,
const std::vector<int>& prompt,
KVCache& kv_cache, hwy::ThreadPool& pool,
int verbosity) {
return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
verbosity);
}
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
// if weights = null, which happens during the first call where we attempt to
// load from cache.
@ -1416,36 +1366,6 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeights(
return c_weights_u8;
}
// Type-erased because this function is called via a function pointer.
hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeightsT(
gcpp::Model model, const Path& weights, hwy::ThreadPool& pool) {
switch (model) {
case Model::GEMMA_2B:
return LoadCompressedWeights<ConfigGemma2B>(weights, pool);
case Model::GEMMA_7B:
return LoadCompressedWeights<ConfigGemma7B>(weights, pool);
case Model::GRIFFIN_2B:
return LoadCompressedWeights<ConfigGriffin2B>(weights, pool);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeightsT(gcpp::Model model,
const Path& weights,
hwy::ThreadPool& pool) {
switch (model) {
case Model::GEMMA_2B:
return LoadWeights<ConfigGemma2B>(weights, pool);
case Model::GEMMA_7B:
return LoadWeights<ConfigGemma7B>(weights, pool);
case Model::GRIFFIN_2B:
return LoadWeights<ConfigGriffin2B>(weights, pool);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
template <class TConfig>
void CompressWeights(const Path& weights_path,
const Path& compressed_weights_path,
@ -1501,15 +1421,7 @@ HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(LoadCompressedWeightsT);
HWY_EXPORT(LoadWeightsT);
HWY_EXPORT(CompressWeightsT);
HWY_EXPORT(Generate2B);
HWY_EXPORT(Generate7B);
HWY_EXPORT(GenerateGriffin2B);
HWY_EXPORT(ComputeCrossEntropy2B);
HWY_EXPORT(ComputeCrossEntropy7B);
HWY_EXPORT(ComputeCrossEntropyGriffin2B);
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
size_t conv1d_cache_size, size_t rglru_cache_size) {
@ -1540,68 +1452,31 @@ GemmaImpl<Config>::GemmaImpl(
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()) {}
template <>
void GemmaImpl<ConfigGemma2B>::Generate(const RuntimeConfig& runtime_config,
template <typename Config>
void GemmaImpl<Config>::Generate(const RuntimeConfig& runtime_config,
const std::vector<int>& prompt,
size_t start_pos, KVCache& kv_cache,
hwy::ThreadPool& pool,
TimingInfo& timing_info,
hwy::ThreadPool& pool, TimingInfo& timing_info,
LayersOutputT* layers_output) {
HWY_DYNAMIC_DISPATCH(Generate2B)
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateImpl<Config>)
(*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
layers_output);
}
template <>
void GemmaImpl<ConfigGemma7B>::Generate(const RuntimeConfig& runtime_config,
template <typename Config>
float GemmaImpl<Config>::ComputeCrossEntropy(size_t max_tokens,
const std::vector<int>& prompt,
size_t start_pos, KVCache& kv_cache,
KVCache& kv_cache,
hwy::ThreadPool& pool,
TimingInfo& timing_info,
LayersOutputT* layers_output) {
HWY_DYNAMIC_DISPATCH(Generate7B)
(*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
layers_output);
}
template <>
void GemmaImpl<ConfigGriffin2B>::Generate(const RuntimeConfig& runtime_config,
const std::vector<int>& prompt,
size_t start_pos, KVCache& kv_cache,
hwy::ThreadPool& pool,
TimingInfo& timing_info,
LayersOutputT* layers_output) {
HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
(*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
layers_output);
}
template <>
float GemmaImpl<ConfigGemma2B>::ComputeCrossEntropy(
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
hwy::ThreadPool& pool, int verbosity) {
return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy2B)(
int verbosity) {
HWY_EXPORT_T(ComputeCrossEntropyT, ComputeCrossEntropyImpl<Config>);
return HWY_DYNAMIC_DISPATCH_T(ComputeCrossEntropyT)(
*this, max_tokens, prompt, kv_cache, pool, verbosity);
}
template <>
float GemmaImpl<ConfigGemma7B>::ComputeCrossEntropy(
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
hwy::ThreadPool& pool, int verbosity) {
return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy7B)(
*this, max_tokens, prompt, kv_cache, pool, verbosity);
}
template <>
float GemmaImpl<ConfigGriffin2B>::ComputeCrossEntropy(
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
hwy::ThreadPool& pool, int verbosity) {
return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropyGriffin2B)(
*this, max_tokens, prompt, kv_cache, pool, verbosity);
}
Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
hwy::ThreadPool& pool) {
template <class Config>
GemmaImpl<Config>* CreateGemmaImpl(const Path& tokenizer_path,
const Path& weights, hwy::ThreadPool& pool) {
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
{
PROFILER_ZONE("Startup.tokenizer");
@ -1613,21 +1488,25 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8;
if constexpr (kWeightsAreCompressed) {
weights_u8 =
HWY_DYNAMIC_DISPATCH(LoadCompressedWeightsT)(model_type, weights, pool);
HWY_EXPORT_T(LoadCompressedWeightsT, LoadCompressedWeights<Config>);
weights_u8 = HWY_DYNAMIC_DISPATCH_T(LoadCompressedWeightsT)(weights, pool);
} else {
weights_u8 = HWY_DYNAMIC_DISPATCH(LoadWeightsT)(model_type, weights, pool);
weights_u8 = LoadWeights<Config>(weights, pool);
}
return new GemmaImpl<Config>(tokenizer, weights_u8, pool);
}
Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
hwy::ThreadPool& pool) {
switch (model_type) {
case Model::GEMMA_2B:
impl_.reset(new GemmaImpl<ConfigGemma2B>(tokenizer, weights_u8, pool));
impl_.reset(CreateGemmaImpl<ConfigGemma2B>(tokenizer_path, weights, pool));
break;
case Model::GEMMA_7B:
impl_.reset(new GemmaImpl<ConfigGemma7B>(tokenizer, weights_u8, pool));
impl_.reset(CreateGemmaImpl<ConfigGemma7B>(tokenizer_path, weights, pool));
break;
case Model::GRIFFIN_2B:
impl_.reset(new GemmaImpl<ConfigGriffin2B>(tokenizer, weights_u8, pool));
impl_.reset(CreateGemmaImpl<ConfigGriffin2B>(tokenizer_path, weights, pool));
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));