diff --git a/gemma.cc b/gemma.cc index ba6aafa..f080dda 100644 --- a/gemma.cc +++ b/gemma.cc @@ -30,6 +30,7 @@ #include "hwy/highway.h" #include "hwy/profiler.h" #include "hwy/timer.h" +#include "util/app.h" // arg types #include "util/args.h" // Path // Non-SIMD includes and types. Note that HWY_ONCE is only true on the last @@ -697,10 +698,10 @@ void ForEachTensor(const Weights* weights, template hwy::AlignedFreeUniquePtr GetCompressedWeights( - const Path& model, const Path& cache, hwy::ThreadPool& pool) { + const Path& weights_path, const Path& cache, hwy::ThreadPool& pool) { PROFILER_ZONE("Startup.LoadCache"); - if (!std::filesystem::exists(model.path) && + if (!std::filesystem::exists(weights_path.path) && !std::filesystem::exists(cache.path)) { HWY_ABORT( "Either the model weights (--weights) or cached compressed weights " @@ -721,7 +722,7 @@ hwy::AlignedFreeUniquePtr GetCompressedWeights( // Get weights, compress, and store in cache. const hwy::AlignedUniquePtr> weights = - LoadWeights(model); + LoadWeights(weights_path); Compressor compressor(pool); ForEachTensor(weights.get(), *c_weights, compressor); compressor.WriteAll(pool, cache.path.c_str()); @@ -731,14 +732,17 @@ hwy::AlignedFreeUniquePtr GetCompressedWeights( // Type-erased because this function is called via a function pointer. hwy::AlignedFreeUniquePtr GetCompressedWeightsT( - const LoaderArgs& args, hwy::ThreadPool& pool) { - switch (args.ModelType()) { + gcpp::Model model, const Path& weights, const Path& compressed_weights, + hwy::ThreadPool& pool) { + switch (model) { case Model::GEMMA_2B: - return GetCompressedWeights(args.model, args.cache, pool); + return GetCompressedWeights(weights, compressed_weights, + pool); case Model::GEMMA_7B: - return GetCompressedWeights(args.model, args.cache, pool); + return GetCompressedWeights(weights, compressed_weights, + pool); default: - HWY_ABORT("Model type %d unknown.", static_cast(args.ModelType())); + HWY_ABORT("Model type %d unknown.", static_cast(model)); } } @@ -799,8 +803,6 @@ void GemmaImpl::Generate( kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity); } -// TODO: Make Gemma type independent of LoaderArgs, create a factory function -// that takes LoaderArgs and creates a Gemma instance. Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) { const Model model_type = args.ModelType(); model_training = args.ModelTraining(); @@ -808,8 +810,8 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) { std::unique_ptr tokenizer = std::make_unique(); HWY_ASSERT(tokenizer->Load(args.tokenizer.path).ok()); - auto compressed_weights = - HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool); + auto compressed_weights = HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)( + args.ModelType(), args.model, args.cache, pool); switch (model_type) { case Model::GEMMA_2B: impl_.reset( diff --git a/gemma.h b/gemma.h index 58fd74a..1a6ca07 100644 --- a/gemma.h +++ b/gemma.h @@ -66,6 +66,9 @@ enum class ModelTraining { GEMMA_IT, GEMMA_PT }; // TODO: Incorporate this struct Runtime { + // TODO: In the future we may fold ModelTraining into Model. + // As we add more variations of model_type, the cartesian set becomes + // unwieldy. Model model_type; ModelTraining model_training; size_t max_tokens; @@ -126,7 +129,7 @@ struct LoaderArgs : public ArgsBase { Path tokenizer; Path model; // uncompressed weights OR - Path cache; // compressed weights + Path cache; // compressed weights (TODO: update name) std::string model_type; template @@ -151,26 +154,6 @@ struct LoaderArgs : public ArgsBase { } }; -struct GemmaInterface; - -struct Gemma { - Gemma(const LoaderArgs& args, hwy::ThreadPool& pool); - ~Gemma(); // must be defined after GemmaInterface's dtor is defined. - - const sentencepiece::SentencePieceProcessor* Tokenizer() const; - - std::unique_ptr impl_; - gcpp::ModelTraining model_training; -}; - -KVCache CreateKVCache(Model type); // convenient workaround for now -KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len); - -// StreamFunc is called with (token, probability). For prompt tokens, -// probability is 0.0f. -using StreamFunc = std::function; -using AcceptFunc = std::function; - struct InferenceArgs : public ArgsBase { InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } @@ -212,6 +195,27 @@ struct InferenceArgs : public ArgsBase { } }; +struct GemmaInterface; + +struct Gemma { + Gemma(const LoaderArgs& args, hwy::ThreadPool& pool); + ~Gemma(); // must be defined after GemmaInterface's dtor is defined. + const sentencepiece::SentencePieceProcessor* Tokenizer() const; + std::unique_ptr impl_; + gcpp::ModelTraining model_training; +}; + +struct LoaderArgs; // forward declaration +void CreateGemma(const LoaderArgs& args, hwy::ThreadPool& pool, Gemma& model); + +KVCache CreateKVCache(Model type); // convenient workaround for now +KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len); + +// StreamFunc is called with (token, probability). For prompt tokens, +// probability is 0.0f. +using StreamFunc = std::function; +using AcceptFunc = std::function; + void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens, float temperature, const std::vector& prompt, size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, diff --git a/run.cc b/run.cc index 40be63e..64b6399 100644 --- a/run.cc +++ b/run.cc @@ -236,6 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { } gcpp::Gemma model(loader, pool); + auto kv_cache = CreateKVCache(loader.ModelType()); if (const char* error = inference.Validate()) {