mirror of https://github.com/google/gemma.cpp.git
Refactor GemmaImpl dispatch to use Highway 1.2's HWY_DYNAMIC_DISPATCH_T
PiperOrigin-RevId: 639793810
This commit is contained in:
parent
a44cbdadc2
commit
ed8f39c058
167
gemma/gemma.cc
167
gemma/gemma.cc
|
|
@ -1257,56 +1257,6 @@ float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
|
|
||||||
#undef TOKEN
|
#undef TOKEN
|
||||||
|
|
||||||
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma,
|
|
||||||
const RuntimeConfig& runtime_config,
|
|
||||||
const std::vector<int>& prompt, size_t start_pos,
|
|
||||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
|
||||||
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
|
||||||
GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
|
|
||||||
timing_info, layers_output);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma,
|
|
||||||
const RuntimeConfig& runtime_config,
|
|
||||||
const std::vector<int>& prompt, size_t start_pos,
|
|
||||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
|
||||||
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
|
||||||
GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
|
|
||||||
timing_info, layers_output);
|
|
||||||
}
|
|
||||||
|
|
||||||
void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
|
|
||||||
const RuntimeConfig& runtime_config,
|
|
||||||
const std::vector<int>& prompt, size_t start_pos,
|
|
||||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
|
||||||
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
|
||||||
GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
|
|
||||||
timing_info, layers_output);
|
|
||||||
}
|
|
||||||
|
|
||||||
float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
|
||||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
|
||||||
hwy::ThreadPool& pool, int verbosity) {
|
|
||||||
return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
|
|
||||||
verbosity);
|
|
||||||
}
|
|
||||||
|
|
||||||
float ComputeCrossEntropy7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
|
|
||||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
|
||||||
hwy::ThreadPool& pool, int verbosity) {
|
|
||||||
return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
|
|
||||||
verbosity);
|
|
||||||
}
|
|
||||||
|
|
||||||
float ComputeCrossEntropyGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
|
|
||||||
size_t max_tokens,
|
|
||||||
const std::vector<int>& prompt,
|
|
||||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
|
||||||
int verbosity) {
|
|
||||||
return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
|
|
||||||
verbosity);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
|
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
|
||||||
// if weights = null, which happens during the first call where we attempt to
|
// if weights = null, which happens during the first call where we attempt to
|
||||||
// load from cache.
|
// load from cache.
|
||||||
|
|
@ -1416,36 +1366,6 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeights(
|
||||||
return c_weights_u8;
|
return c_weights_u8;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Type-erased because this function is called via a function pointer.
|
|
||||||
hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeightsT(
|
|
||||||
gcpp::Model model, const Path& weights, hwy::ThreadPool& pool) {
|
|
||||||
switch (model) {
|
|
||||||
case Model::GEMMA_2B:
|
|
||||||
return LoadCompressedWeights<ConfigGemma2B>(weights, pool);
|
|
||||||
case Model::GEMMA_7B:
|
|
||||||
return LoadCompressedWeights<ConfigGemma7B>(weights, pool);
|
|
||||||
case Model::GRIFFIN_2B:
|
|
||||||
return LoadCompressedWeights<ConfigGriffin2B>(weights, pool);
|
|
||||||
default:
|
|
||||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeightsT(gcpp::Model model,
|
|
||||||
const Path& weights,
|
|
||||||
hwy::ThreadPool& pool) {
|
|
||||||
switch (model) {
|
|
||||||
case Model::GEMMA_2B:
|
|
||||||
return LoadWeights<ConfigGemma2B>(weights, pool);
|
|
||||||
case Model::GEMMA_7B:
|
|
||||||
return LoadWeights<ConfigGemma7B>(weights, pool);
|
|
||||||
case Model::GRIFFIN_2B:
|
|
||||||
return LoadWeights<ConfigGriffin2B>(weights, pool);
|
|
||||||
default:
|
|
||||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class TConfig>
|
template <class TConfig>
|
||||||
void CompressWeights(const Path& weights_path,
|
void CompressWeights(const Path& weights_path,
|
||||||
const Path& compressed_weights_path,
|
const Path& compressed_weights_path,
|
||||||
|
|
@ -1501,15 +1421,7 @@ HWY_AFTER_NAMESPACE();
|
||||||
#if HWY_ONCE
|
#if HWY_ONCE
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
HWY_EXPORT(LoadCompressedWeightsT);
|
|
||||||
HWY_EXPORT(LoadWeightsT);
|
|
||||||
HWY_EXPORT(CompressWeightsT);
|
HWY_EXPORT(CompressWeightsT);
|
||||||
HWY_EXPORT(Generate2B);
|
|
||||||
HWY_EXPORT(Generate7B);
|
|
||||||
HWY_EXPORT(GenerateGriffin2B);
|
|
||||||
HWY_EXPORT(ComputeCrossEntropy2B);
|
|
||||||
HWY_EXPORT(ComputeCrossEntropy7B);
|
|
||||||
HWY_EXPORT(ComputeCrossEntropyGriffin2B);
|
|
||||||
|
|
||||||
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
|
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
|
||||||
size_t conv1d_cache_size, size_t rglru_cache_size) {
|
size_t conv1d_cache_size, size_t rglru_cache_size) {
|
||||||
|
|
@ -1540,68 +1452,31 @@ GemmaImpl<Config>::GemmaImpl(
|
||||||
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
|
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
|
||||||
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()) {}
|
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()) {}
|
||||||
|
|
||||||
template <>
|
template <typename Config>
|
||||||
void GemmaImpl<ConfigGemma2B>::Generate(const RuntimeConfig& runtime_config,
|
void GemmaImpl<Config>::Generate(const RuntimeConfig& runtime_config,
|
||||||
const std::vector<int>& prompt,
|
const std::vector<int>& prompt,
|
||||||
size_t start_pos, KVCache& kv_cache,
|
size_t start_pos, KVCache& kv_cache,
|
||||||
hwy::ThreadPool& pool,
|
hwy::ThreadPool& pool, TimingInfo& timing_info,
|
||||||
TimingInfo& timing_info,
|
|
||||||
LayersOutputT* layers_output) {
|
LayersOutputT* layers_output) {
|
||||||
HWY_DYNAMIC_DISPATCH(Generate2B)
|
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateImpl<Config>)
|
||||||
(*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
|
(*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
|
||||||
layers_output);
|
layers_output);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <typename Config>
|
||||||
void GemmaImpl<ConfigGemma7B>::Generate(const RuntimeConfig& runtime_config,
|
float GemmaImpl<Config>::ComputeCrossEntropy(size_t max_tokens,
|
||||||
const std::vector<int>& prompt,
|
const std::vector<int>& prompt,
|
||||||
size_t start_pos, KVCache& kv_cache,
|
KVCache& kv_cache,
|
||||||
hwy::ThreadPool& pool,
|
hwy::ThreadPool& pool,
|
||||||
TimingInfo& timing_info,
|
int verbosity) {
|
||||||
LayersOutputT* layers_output) {
|
HWY_EXPORT_T(ComputeCrossEntropyT, ComputeCrossEntropyImpl<Config>);
|
||||||
HWY_DYNAMIC_DISPATCH(Generate7B)
|
return HWY_DYNAMIC_DISPATCH_T(ComputeCrossEntropyT)(
|
||||||
(*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
|
|
||||||
layers_output);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
void GemmaImpl<ConfigGriffin2B>::Generate(const RuntimeConfig& runtime_config,
|
|
||||||
const std::vector<int>& prompt,
|
|
||||||
size_t start_pos, KVCache& kv_cache,
|
|
||||||
hwy::ThreadPool& pool,
|
|
||||||
TimingInfo& timing_info,
|
|
||||||
LayersOutputT* layers_output) {
|
|
||||||
HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
|
|
||||||
(*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
|
|
||||||
layers_output);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
float GemmaImpl<ConfigGemma2B>::ComputeCrossEntropy(
|
|
||||||
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
|
|
||||||
hwy::ThreadPool& pool, int verbosity) {
|
|
||||||
return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy2B)(
|
|
||||||
*this, max_tokens, prompt, kv_cache, pool, verbosity);
|
*this, max_tokens, prompt, kv_cache, pool, verbosity);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <class Config>
|
||||||
float GemmaImpl<ConfigGemma7B>::ComputeCrossEntropy(
|
GemmaImpl<Config>* CreateGemmaImpl(const Path& tokenizer_path,
|
||||||
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
|
const Path& weights, hwy::ThreadPool& pool) {
|
||||||
hwy::ThreadPool& pool, int verbosity) {
|
|
||||||
return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy7B)(
|
|
||||||
*this, max_tokens, prompt, kv_cache, pool, verbosity);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
float GemmaImpl<ConfigGriffin2B>::ComputeCrossEntropy(
|
|
||||||
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
|
|
||||||
hwy::ThreadPool& pool, int verbosity) {
|
|
||||||
return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropyGriffin2B)(
|
|
||||||
*this, max_tokens, prompt, kv_cache, pool, verbosity);
|
|
||||||
}
|
|
||||||
|
|
||||||
Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
|
||||||
hwy::ThreadPool& pool) {
|
|
||||||
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
|
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
|
||||||
{
|
{
|
||||||
PROFILER_ZONE("Startup.tokenizer");
|
PROFILER_ZONE("Startup.tokenizer");
|
||||||
|
|
@ -1613,21 +1488,25 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
||||||
|
|
||||||
hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8;
|
hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8;
|
||||||
if constexpr (kWeightsAreCompressed) {
|
if constexpr (kWeightsAreCompressed) {
|
||||||
weights_u8 =
|
HWY_EXPORT_T(LoadCompressedWeightsT, LoadCompressedWeights<Config>);
|
||||||
HWY_DYNAMIC_DISPATCH(LoadCompressedWeightsT)(model_type, weights, pool);
|
weights_u8 = HWY_DYNAMIC_DISPATCH_T(LoadCompressedWeightsT)(weights, pool);
|
||||||
} else {
|
} else {
|
||||||
weights_u8 = HWY_DYNAMIC_DISPATCH(LoadWeightsT)(model_type, weights, pool);
|
weights_u8 = LoadWeights<Config>(weights, pool);
|
||||||
}
|
}
|
||||||
|
return new GemmaImpl<Config>(tokenizer, weights_u8, pool);
|
||||||
|
}
|
||||||
|
|
||||||
|
Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
||||||
|
hwy::ThreadPool& pool) {
|
||||||
switch (model_type) {
|
switch (model_type) {
|
||||||
case Model::GEMMA_2B:
|
case Model::GEMMA_2B:
|
||||||
impl_.reset(new GemmaImpl<ConfigGemma2B>(tokenizer, weights_u8, pool));
|
impl_.reset(CreateGemmaImpl<ConfigGemma2B>(tokenizer_path, weights, pool));
|
||||||
break;
|
break;
|
||||||
case Model::GEMMA_7B:
|
case Model::GEMMA_7B:
|
||||||
impl_.reset(new GemmaImpl<ConfigGemma7B>(tokenizer, weights_u8, pool));
|
impl_.reset(CreateGemmaImpl<ConfigGemma7B>(tokenizer_path, weights, pool));
|
||||||
break;
|
break;
|
||||||
case Model::GRIFFIN_2B:
|
case Model::GRIFFIN_2B:
|
||||||
impl_.reset(new GemmaImpl<ConfigGriffin2B>(tokenizer, weights_u8, pool));
|
impl_.reset(CreateGemmaImpl<ConfigGriffin2B>(tokenizer_path, weights, pool));
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
|
HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue