mirror of https://github.com/google/gemma.cpp.git
[WIP] remove args from GetWeights, GetCompressedWeights
parent 3df06f64c2
commit b67e28d1a0

gemma.cc (26 changed lines)
@@ -30,6 +30,7 @@
 #include "hwy/highway.h"
 #include "hwy/profiler.h"
 #include "hwy/timer.h"
 #include "util/app.h"   // arg types
 #include "util/args.h"  // Path

 // Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
@@ -697,10 +698,10 @@ void ForEachTensor(const Weights<TConfig>* weights,

 template <class TConfig>
 hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(
-    const Path& model, const Path& cache, hwy::ThreadPool& pool) {
+    const Path& weights_path, const Path& cache, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Startup.LoadCache");

-  if (!std::filesystem::exists(model.path) &&
+  if (!std::filesystem::exists(weights_path.path) &&
       !std::filesystem::exists(cache.path)) {
     HWY_ABORT(
         "Either the model weights (--weights) or cached compressed weights "
@@ -721,7 +722,7 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(

   // Get weights, compress, and store in cache.
   const hwy::AlignedUniquePtr<Weights<TConfig>> weights =
-      LoadWeights<TConfig>(model);
+      LoadWeights<TConfig>(weights_path);
   Compressor compressor(pool);
   ForEachTensor<TConfig>(weights.get(), *c_weights, compressor);
   compressor.WriteAll(pool, cache.path.c_str());
@@ -731,14 +732,17 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(

 // Type-erased because this function is called via a function pointer.
 hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeightsT(
-    const LoaderArgs& args, hwy::ThreadPool& pool) {
-  switch (args.ModelType()) {
+    gcpp::Model model, const Path& weights, const Path& compressed_weights,
+    hwy::ThreadPool& pool) {
+  switch (model) {
     case Model::GEMMA_2B:
-      return GetCompressedWeights<ConfigGemma2B>(args.model, args.cache, pool);
+      return GetCompressedWeights<ConfigGemma2B>(weights, compressed_weights,
+                                                 pool);
     case Model::GEMMA_7B:
-      return GetCompressedWeights<ConfigGemma7B>(args.model, args.cache, pool);
+      return GetCompressedWeights<ConfigGemma7B>(weights, compressed_weights,
+                                                 pool);
     default:
-      HWY_ABORT("Model type %d unknown.", static_cast<int>(args.ModelType()));
+      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
   }
 }
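
With LoaderArgs removed from its signature, the type-erased entry point needs only a model enum and two Path values. A minimal sketch of a direct call, assuming it is made inside gemma.cc where GetCompressedWeightsT is exported for HWY_DYNAMIC_DISPATCH; the helper name and file paths are hypothetical:

// Sketch only, not part of this commit. Path is the util/args.h wrapper
// whose .path member is checked with std::filesystem::exists above.
hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressed2BForTest(
    hwy::ThreadPool& pool) {
  gcpp::Path weights;
  weights.path = "2b-it.sbs";                        // hypothetical weights file
  gcpp::Path compressed_weights;
  compressed_weights.path = "2b-it-compressed.sbs";  // hypothetical cache file
  return HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(
      gcpp::Model::GEMMA_2B, weights, compressed_weights, pool);
}

The call no longer depends on how the paths were obtained; the Gemma constructor below simply passes the values it pulls out of LoaderArgs.
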
@@ -799,8 +803,6 @@ void GemmaImpl<ConfigGemma7B>::Generate(
       kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
 }

-// TODO: Make Gemma type independent of LoaderArgs, create a factory function
-// that takes LoaderArgs and creates a Gemma instance.
 Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
   const Model model_type = args.ModelType();
   model_training = args.ModelTraining();
@@ -808,8 +810,8 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
   std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer =
       std::make_unique<sentencepiece::SentencePieceProcessor>();
   HWY_ASSERT(tokenizer->Load(args.tokenizer.path).ok());
-  auto compressed_weights =
-      HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool);
+  auto compressed_weights = HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(
+      args.ModelType(), args.model, args.cache, pool);
   switch (model_type) {
     case Model::GEMMA_2B:
       impl_.reset(

gemma.h (46 changed lines)
@@ -66,6 +66,9 @@ enum class ModelTraining { GEMMA_IT, GEMMA_PT };

 // TODO: Incorporate this
 struct Runtime {
+  // TODO: In the future we may fold ModelTraining into Model.
+  // As we add more variations of model_type, the cartesian set becomes
+  // unwieldy.
   Model model_type;
   ModelTraining model_training;
   size_t max_tokens;
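
The "cartesian set" concern in the new TODO can be made concrete: folding ModelTraining into Model means one enumerator per (architecture, training) pair, so every new architecture or training variant multiplies the list. A purely illustrative sketch, not part of this commit:

// Hypothetical folded enum: the IT/PT axis is merged into Model, and the
// number of enumerators becomes the product of both axes.
enum class Model {
  GEMMA_2B_IT, GEMMA_2B_PT,
  GEMMA_7B_IT, GEMMA_7B_PT,
};
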
@@ -126,7 +129,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {

   Path tokenizer;
   Path model;  // uncompressed weights OR
-  Path cache;  // compressed weights
+  Path cache;  // compressed weights (TODO: update name)
   std::string model_type;

   template <class Visitor>
@@ -151,26 +154,6 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
   }
 };

-struct GemmaInterface;
-
-struct Gemma {
-  Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
-  ~Gemma();  // must be defined after GemmaInterface's dtor is defined.
-
-  const sentencepiece::SentencePieceProcessor* Tokenizer() const;
-
-  std::unique_ptr<GemmaInterface> impl_;
-  gcpp::ModelTraining model_training;
-};
-
-KVCache CreateKVCache(Model type);  // convenient workaround for now
-KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);
-
-// StreamFunc is called with (token, probability). For prompt tokens,
-// probability is 0.0f.
-using StreamFunc = std::function<bool(int, float)>;
-using AcceptFunc = std::function<bool(int)>;
-
 struct InferenceArgs : public ArgsBase<InferenceArgs> {
   InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

@@ -212,6 +195,27 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   }
 };

+struct GemmaInterface;
+
+struct Gemma {
+  Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
+  ~Gemma();  // must be defined after GemmaInterface's dtor is defined.
+  const sentencepiece::SentencePieceProcessor* Tokenizer() const;
+  std::unique_ptr<GemmaInterface> impl_;
+  gcpp::ModelTraining model_training;
+};
+
+struct LoaderArgs;  // forward declaration
+void CreateGemma(const LoaderArgs& args, hwy::ThreadPool& pool, Gemma& model);
+
+KVCache CreateKVCache(Model type);  // convenient workaround for now
+KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);
+
+// StreamFunc is called with (token, probability). For prompt tokens,
+// probability is 0.0f.
+using StreamFunc = std::function<bool(int, float)>;
+using AcceptFunc = std::function<bool(int)>;
+
 void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
                    float temperature, const std::vector<int>& prompt,
                    size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
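
Taken together with the Gemma struct and the CreateKVCache / StreamFunc declarations now placed above it, this is the intended library-style call sequence. A minimal usage sketch based only on the declarations visible in this diff; the trailing GenerateGemma parameters are an assumption inferred from the internal Generate call shown in gemma.cc (inner pool, stream/accept callbacks, RNG, verbosity), since the declaration is cut off in this view:

#include <random>
#include <vector>

#include "gemma.h"  // Gemma, LoaderArgs, CreateKVCache, GenerateGemma
#include "hwy/contrib/thread_pool/thread_pool.h"

int main(int argc, char** argv) {
  // Assumes LoaderArgs, like InferenceArgs above, is constructed from argc/argv.
  gcpp::LoaderArgs loader(argc, argv);
  hwy::ThreadPool pool(4);
  hwy::ThreadPool inner_pool(1);

  gcpp::Gemma model(loader, pool);  // loads tokenizer and compressed weights
  gcpp::KVCache kv_cache = gcpp::CreateKVCache(loader.ModelType());

  // Hypothetical token ids; in practice they would come from
  // model.Tokenizer()->Encode(...).
  std::vector<int> prompt = {2, 651, 6037};

  std::mt19937 gen(42);
  // StreamFunc receives (token, probability); returning true keeps generating.
  gcpp::StreamFunc stream_token = [](int, float) { return true; };
  gcpp::AcceptFunc accept_token = [](int) { return true; };

  // The arguments after `pool` are assumed (see note above).
  gcpp::GenerateGemma(model, /*max_tokens=*/2048, /*max_generated_tokens=*/512,
                      /*temperature=*/1.0f, prompt, /*start_pos=*/0, kv_cache,
                      pool, inner_pool, stream_token, accept_token, gen,
                      /*verbosity=*/1);
  return 0;
}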