mirror of https://github.com/google/gemma.cpp.git
Use HWY_EXPORT_AND_DYNAMIC_DISPATCH_T where possible.
This commit is contained in:
parent
8567978541
commit
3b4fa4a0e3
112
gemma/gemma.cc
112
gemma/gemma.cc
|
|
@ -1106,12 +1106,12 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma,
|
|||
}
|
||||
|
||||
template <class TConfig>
|
||||
void GenerateImpl(const ByteStorageT& weights_u8,
|
||||
ByteStorageT& inference_state_u8,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt, size_t pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
||||
void GenerateGemma(const ByteStorageT& weights_u8,
|
||||
ByteStorageT& inference_state_u8,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt, size_t pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
||||
const WeightsF<TConfig>& weights =
|
||||
*reinterpret_cast<const WeightsF<TConfig>*>(weights_u8.get());
|
||||
InferenceState<TConfig>& inference_state =
|
||||
|
|
@ -1121,28 +1121,6 @@ void GenerateImpl(const ByteStorageT& weights_u8,
|
|||
prompt, pos, kv_cache, pool, timing_info, layers_output);
|
||||
}
|
||||
|
||||
void GenerateImplT(Model model, const ByteStorageT& weights_u8,
|
||||
ByteStorageT& inference_state_u8,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt, size_t pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
||||
switch (model) {
|
||||
case Model::GEMMA_2B:
|
||||
GenerateImpl<ConfigGemma2B>(
|
||||
weights_u8, inference_state_u8, runtime_config, prompt, pos, kv_cache,
|
||||
pool, timing_info, layers_output);
|
||||
break;
|
||||
case Model::GEMMA_TINY:
|
||||
GenerateImpl<ConfigGemmaTiny>(
|
||||
weights_u8, inference_state_u8, runtime_config, prompt, pos, kv_cache,
|
||||
pool, timing_info, layers_output);
|
||||
break;
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||
}
|
||||
}
|
||||
|
||||
#define TOKEN(token_id) TokenString(gemma, token_id).c_str()
|
||||
|
||||
template <class TConfig>
|
||||
|
|
@ -1348,23 +1326,6 @@ void CompressWeights(const Path& weights_path,
|
|||
c_weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();
|
||||
}
|
||||
|
||||
void CompressWeightsT(gcpp::Model model, const Path& weights,
|
||||
const Path& compressed_weights, hwy::ThreadPool& pool) {
|
||||
switch (model) {
|
||||
case Model::GEMMA_2B:
|
||||
CompressWeights<ConfigGemma2B>(weights, compressed_weights, pool);
|
||||
break;
|
||||
case Model::GEMMA_7B:
|
||||
CompressWeights<ConfigGemma7B>(weights, compressed_weights, pool);
|
||||
break;
|
||||
case Model::GRIFFIN_2B:
|
||||
CompressWeights<ConfigGriffin2B>(weights, compressed_weights, pool);
|
||||
break;
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
|
@ -1372,9 +1333,6 @@ HWY_AFTER_NAMESPACE();
|
|||
#if HWY_ONCE
|
||||
namespace gcpp {
|
||||
|
||||
HWY_EXPORT(CompressWeightsT);
|
||||
HWY_EXPORT(GenerateImplT);
|
||||
|
||||
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
|
||||
size_t conv1d_cache_size, size_t rglru_cache_size) {
|
||||
KVCache kv_cache = {};
|
||||
|
|
@ -1486,14 +1444,46 @@ void GenerateGemma(Model model, const ByteStorageT& weights,
|
|||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
TimingInfo& timing_info) {
|
||||
HWY_DYNAMIC_DISPATCH(GenerateImplT)(
|
||||
model, weights, inference_state, runtime_config, prompt, start_pos,
|
||||
kv_cache, pool, timing_info, /*layers_output=*/nullptr);
|
||||
switch (model) {
|
||||
case Model::GEMMA_2B:
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma<ConfigGemma2B>)(
|
||||
weights, inference_state, runtime_config, prompt, start_pos, kv_cache,
|
||||
pool, timing_info, /*layers_output=*/nullptr);
|
||||
break;
|
||||
case Model::GEMMA_7B:
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma<ConfigGemma7B>)(
|
||||
weights, inference_state, runtime_config, prompt, start_pos, kv_cache,
|
||||
pool, timing_info, /*layers_output=*/nullptr);
|
||||
break;
|
||||
case Model::GRIFFIN_2B:
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma<ConfigGriffin2B>)(
|
||||
weights, inference_state, runtime_config, prompt, start_pos, kv_cache,
|
||||
pool, timing_info, /*layers_output=*/nullptr);
|
||||
break;
|
||||
case Model::GEMMA_TINY:
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma<ConfigGemmaTiny>)(
|
||||
weights, inference_state, runtime_config, prompt, start_pos, kv_cache,
|
||||
pool, timing_info, /*layers_output=*/nullptr);
|
||||
break;
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||
}
|
||||
}
|
||||
|
||||
ByteStorageT LoadWeights(const Path& weights, Model model,
|
||||
hwy::ThreadPool& pool) {
|
||||
return HWY_DYNAMIC_DISPATCH(LoadWeightsT)(model, weights, pool);
|
||||
switch (model) {
|
||||
case Model::GEMMA_2B:
|
||||
return LoadWeights<ConfigGemma2B>(weights, pool);
|
||||
case Model::GEMMA_7B:
|
||||
return LoadWeights<ConfigGemma7B>(weights, pool);
|
||||
case Model::GRIFFIN_2B:
|
||||
return LoadWeights<ConfigGriffin2B>(weights, pool);
|
||||
case Model::GEMMA_TINY:
|
||||
return LoadWeights<ConfigGemmaTiny>(weights, pool);
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||
}
|
||||
}
|
||||
|
||||
ByteStorageT AllocateInferenceState(Model model) {
|
||||
|
|
@ -1513,8 +1503,22 @@ ByteStorageT AllocateInferenceState(Model model) {
|
|||
|
||||
void CompressWeights(gcpp::Model model, const Path& weights,
|
||||
const Path& compressed_weights, hwy::ThreadPool& pool) {
|
||||
HWY_DYNAMIC_DISPATCH(CompressWeightsT)
|
||||
(model, weights, compressed_weights, pool);
|
||||
switch (model) {
|
||||
case Model::GEMMA_2B:
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<ConfigGemma2B>)(
|
||||
weights, compressed_weights, pool);
|
||||
break;
|
||||
case Model::GEMMA_7B:
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<ConfigGemma7B>)(
|
||||
weights, compressed_weights, pool);
|
||||
break;
|
||||
case Model::GRIFFIN_2B:
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<ConfigGriffin2B>)(
|
||||
weights, compressed_weights, pool);
|
||||
break;
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||
}
|
||||
}
|
||||
|
||||
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
|
||||
|
|
|
|||
Loading…
Reference in New Issue