mirror of https://github.com/google/gemma.cpp.git
Refactor GemmaImpl dispatch to use Highway 1.2's HWY_DYNAMIC_DISPATCH_T
PiperOrigin-RevId: 639793810
This commit is contained in:
parent
a44cbdadc2
commit
ed8f39c058
167
gemma/gemma.cc
167
gemma/gemma.cc
|
|
@ -1257,56 +1257,6 @@ float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
|
||||
#undef TOKEN
|
||||
|
||||
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
||||
GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
|
||||
timing_info, layers_output);
|
||||
}
|
||||
|
||||
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
||||
GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
|
||||
timing_info, layers_output);
|
||||
}
|
||||
|
||||
void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
||||
GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
|
||||
timing_info, layers_output);
|
||||
}
|
||||
|
||||
float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, int verbosity) {
|
||||
return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
|
||||
verbosity);
|
||||
}
|
||||
|
||||
float ComputeCrossEntropy7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
|
||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, int verbosity) {
|
||||
return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
|
||||
verbosity);
|
||||
}
|
||||
|
||||
float ComputeCrossEntropyGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
|
||||
size_t max_tokens,
|
||||
const std::vector<int>& prompt,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
int verbosity) {
|
||||
return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
|
||||
verbosity);
|
||||
}
|
||||
|
||||
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
|
||||
// if weights = null, which happens during the first call where we attempt to
|
||||
// load from cache.
|
||||
|
|
@ -1416,36 +1366,6 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeights(
|
|||
return c_weights_u8;
|
||||
}
|
||||
|
||||
// Type-erased because this function is called via a function pointer.
|
||||
hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeightsT(
|
||||
gcpp::Model model, const Path& weights, hwy::ThreadPool& pool) {
|
||||
switch (model) {
|
||||
case Model::GEMMA_2B:
|
||||
return LoadCompressedWeights<ConfigGemma2B>(weights, pool);
|
||||
case Model::GEMMA_7B:
|
||||
return LoadCompressedWeights<ConfigGemma7B>(weights, pool);
|
||||
case Model::GRIFFIN_2B:
|
||||
return LoadCompressedWeights<ConfigGriffin2B>(weights, pool);
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||
}
|
||||
}
|
||||
|
||||
hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeightsT(gcpp::Model model,
|
||||
const Path& weights,
|
||||
hwy::ThreadPool& pool) {
|
||||
switch (model) {
|
||||
case Model::GEMMA_2B:
|
||||
return LoadWeights<ConfigGemma2B>(weights, pool);
|
||||
case Model::GEMMA_7B:
|
||||
return LoadWeights<ConfigGemma7B>(weights, pool);
|
||||
case Model::GRIFFIN_2B:
|
||||
return LoadWeights<ConfigGriffin2B>(weights, pool);
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||
}
|
||||
}
|
||||
|
||||
template <class TConfig>
|
||||
void CompressWeights(const Path& weights_path,
|
||||
const Path& compressed_weights_path,
|
||||
|
|
@ -1501,15 +1421,7 @@ HWY_AFTER_NAMESPACE();
|
|||
#if HWY_ONCE
|
||||
namespace gcpp {
|
||||
|
||||
HWY_EXPORT(LoadCompressedWeightsT);
|
||||
HWY_EXPORT(LoadWeightsT);
|
||||
HWY_EXPORT(CompressWeightsT);
|
||||
HWY_EXPORT(Generate2B);
|
||||
HWY_EXPORT(Generate7B);
|
||||
HWY_EXPORT(GenerateGriffin2B);
|
||||
HWY_EXPORT(ComputeCrossEntropy2B);
|
||||
HWY_EXPORT(ComputeCrossEntropy7B);
|
||||
HWY_EXPORT(ComputeCrossEntropyGriffin2B);
|
||||
|
||||
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
|
||||
size_t conv1d_cache_size, size_t rglru_cache_size) {
|
||||
|
|
@ -1540,68 +1452,31 @@ GemmaImpl<Config>::GemmaImpl(
|
|||
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
|
||||
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()) {}
|
||||
|
||||
template <>
|
||||
void GemmaImpl<ConfigGemma2B>::Generate(const RuntimeConfig& runtime_config,
|
||||
template <typename Config>
|
||||
void GemmaImpl<Config>::Generate(const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt,
|
||||
size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool,
|
||||
TimingInfo& timing_info,
|
||||
hwy::ThreadPool& pool, TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) {
|
||||
HWY_DYNAMIC_DISPATCH(Generate2B)
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateImpl<Config>)
|
||||
(*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
|
||||
layers_output);
|
||||
}
|
||||
|
||||
template <>
|
||||
void GemmaImpl<ConfigGemma7B>::Generate(const RuntimeConfig& runtime_config,
|
||||
template <typename Config>
|
||||
float GemmaImpl<Config>::ComputeCrossEntropy(size_t max_tokens,
|
||||
const std::vector<int>& prompt,
|
||||
size_t start_pos, KVCache& kv_cache,
|
||||
KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool,
|
||||
TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) {
|
||||
HWY_DYNAMIC_DISPATCH(Generate7B)
|
||||
(*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
|
||||
layers_output);
|
||||
}
|
||||
|
||||
template <>
|
||||
void GemmaImpl<ConfigGriffin2B>::Generate(const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt,
|
||||
size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool,
|
||||
TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) {
|
||||
HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
|
||||
(*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
|
||||
layers_output);
|
||||
}
|
||||
|
||||
template <>
|
||||
float GemmaImpl<ConfigGemma2B>::ComputeCrossEntropy(
|
||||
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, int verbosity) {
|
||||
return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy2B)(
|
||||
int verbosity) {
|
||||
HWY_EXPORT_T(ComputeCrossEntropyT, ComputeCrossEntropyImpl<Config>);
|
||||
return HWY_DYNAMIC_DISPATCH_T(ComputeCrossEntropyT)(
|
||||
*this, max_tokens, prompt, kv_cache, pool, verbosity);
|
||||
}
|
||||
|
||||
template <>
|
||||
float GemmaImpl<ConfigGemma7B>::ComputeCrossEntropy(
|
||||
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, int verbosity) {
|
||||
return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy7B)(
|
||||
*this, max_tokens, prompt, kv_cache, pool, verbosity);
|
||||
}
|
||||
|
||||
template <>
|
||||
float GemmaImpl<ConfigGriffin2B>::ComputeCrossEntropy(
|
||||
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, int verbosity) {
|
||||
return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropyGriffin2B)(
|
||||
*this, max_tokens, prompt, kv_cache, pool, verbosity);
|
||||
}
|
||||
|
||||
Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
||||
hwy::ThreadPool& pool) {
|
||||
template <class Config>
|
||||
GemmaImpl<Config>* CreateGemmaImpl(const Path& tokenizer_path,
|
||||
const Path& weights, hwy::ThreadPool& pool) {
|
||||
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
|
||||
{
|
||||
PROFILER_ZONE("Startup.tokenizer");
|
||||
|
|
@ -1613,21 +1488,25 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
|||
|
||||
hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8;
|
||||
if constexpr (kWeightsAreCompressed) {
|
||||
weights_u8 =
|
||||
HWY_DYNAMIC_DISPATCH(LoadCompressedWeightsT)(model_type, weights, pool);
|
||||
HWY_EXPORT_T(LoadCompressedWeightsT, LoadCompressedWeights<Config>);
|
||||
weights_u8 = HWY_DYNAMIC_DISPATCH_T(LoadCompressedWeightsT)(weights, pool);
|
||||
} else {
|
||||
weights_u8 = HWY_DYNAMIC_DISPATCH(LoadWeightsT)(model_type, weights, pool);
|
||||
weights_u8 = LoadWeights<Config>(weights, pool);
|
||||
}
|
||||
return new GemmaImpl<Config>(tokenizer, weights_u8, pool);
|
||||
}
|
||||
|
||||
Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
||||
hwy::ThreadPool& pool) {
|
||||
switch (model_type) {
|
||||
case Model::GEMMA_2B:
|
||||
impl_.reset(new GemmaImpl<ConfigGemma2B>(tokenizer, weights_u8, pool));
|
||||
impl_.reset(CreateGemmaImpl<ConfigGemma2B>(tokenizer_path, weights, pool));
|
||||
break;
|
||||
case Model::GEMMA_7B:
|
||||
impl_.reset(new GemmaImpl<ConfigGemma7B>(tokenizer, weights_u8, pool));
|
||||
impl_.reset(CreateGemmaImpl<ConfigGemma7B>(tokenizer_path, weights, pool));
|
||||
break;
|
||||
case Model::GRIFFIN_2B:
|
||||
impl_.reset(new GemmaImpl<ConfigGriffin2B>(tokenizer, weights_u8, pool));
|
||||
impl_.reset(CreateGemmaImpl<ConfigGriffin2B>(tokenizer_path, weights, pool));
|
||||
break;
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
|
||||
|
|
|
|||
Loading…
Reference in New Issue