[WIP] remove args from GetWeights, GetCompressedWeights

This commit is contained in:
austinvhuang 2024-03-08 00:00:11 -05:00
parent 3df06f64c2
commit b67e28d1a0
3 changed files with 40 additions and 33 deletions

View File

@ -30,6 +30,7 @@
#include "hwy/highway.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
#include "util/app.h" // arg types
#include "util/args.h" // Path
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
@ -697,10 +698,10 @@ void ForEachTensor(const Weights<TConfig>* weights,
template <class TConfig>
hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(
const Path& model, const Path& cache, hwy::ThreadPool& pool) {
const Path& weights_path, const Path& cache, hwy::ThreadPool& pool) {
PROFILER_ZONE("Startup.LoadCache");
if (!std::filesystem::exists(model.path) &&
if (!std::filesystem::exists(weights_path.path) &&
!std::filesystem::exists(cache.path)) {
HWY_ABORT(
"Either the model weights (--weights) or cached compressed weights "
@ -721,7 +722,7 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(
// Get weights, compress, and store in cache.
const hwy::AlignedUniquePtr<Weights<TConfig>> weights =
LoadWeights<TConfig>(model);
LoadWeights<TConfig>(weights_path);
Compressor compressor(pool);
ForEachTensor<TConfig>(weights.get(), *c_weights, compressor);
compressor.WriteAll(pool, cache.path.c_str());
@ -731,14 +732,17 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(
// Type-erased because this function is called via a function pointer.
// Selects the GetCompressedWeights<TConfig> instantiation that matches the
// model type (ConfigGemma2B or ConfigGemma7B) and returns its result. Any
// other model type aborts via HWY_ABORT.
hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeightsT(
const LoaderArgs& args, hwy::ThreadPool& pool) {
switch (args.ModelType()) {
gcpp::Model model, const Path& weights, const Path& compressed_weights,
hwy::ThreadPool& pool) {
switch (model) {
case Model::GEMMA_2B:
return GetCompressedWeights<ConfigGemma2B>(args.model, args.cache, pool);
return GetCompressedWeights<ConfigGemma2B>(weights, compressed_weights,
pool);
case Model::GEMMA_7B:
return GetCompressedWeights<ConfigGemma7B>(args.model, args.cache, pool);
return GetCompressedWeights<ConfigGemma7B>(weights, compressed_weights,
pool);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(args.ModelType()));
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
@ -799,8 +803,6 @@ void GemmaImpl<ConfigGemma7B>::Generate(
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
}
// TODO: Make Gemma type independent of LoaderArgs, create a factory function
// that takes LoaderArgs and creates a Gemma instance.
Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
const Model model_type = args.ModelType();
model_training = args.ModelTraining();
@ -808,8 +810,8 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer =
std::make_unique<sentencepiece::SentencePieceProcessor>();
HWY_ASSERT(tokenizer->Load(args.tokenizer.path).ok());
auto compressed_weights =
HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool);
auto compressed_weights = HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(
args.ModelType(), args.model, args.cache, pool);
switch (model_type) {
case Model::GEMMA_2B:
impl_.reset(

46
gemma.h
View File

@ -66,6 +66,9 @@ enum class ModelTraining { GEMMA_IT, GEMMA_PT };
// TODO: Incorporate this
struct Runtime {
// TODO: In the future we may fold ModelTraining into Model.
// As we add more variations of model_type, the Cartesian product of
// (Model, ModelTraining) combinations becomes unwieldy.
Model model_type;
ModelTraining model_training;
size_t max_tokens;
@ -126,7 +129,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
Path tokenizer;
Path model; // uncompressed weights OR
Path cache; // compressed weights
Path cache; // compressed weights (TODO: update name)
std::string model_type;
template <class Visitor>
@ -151,26 +154,6 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
}
};
struct GemmaInterface;
struct Gemma {
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
std::unique_ptr<GemmaInterface> impl_;
gcpp::ModelTraining model_training;
};
KVCache CreateKVCache(Model type); // convenient workaround for now
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);
// StreamFunc is called with (token, probability). For prompt tokens,
// probability is 0.0f.
using StreamFunc = std::function<bool(int, float)>;
using AcceptFunc = std::function<bool(int)>;
struct InferenceArgs : public ArgsBase<InferenceArgs> {
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
@ -212,6 +195,27 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
}
};
struct GemmaInterface;
// Top-level handle for a loaded model. Holds the type-erased implementation
// and the training variant reported by the loader arguments.
struct Gemma {
// Builds the model from loader arguments: reads the model type and training
// variant from `args`, loads the tokenizer from args.tokenizer.path, obtains
// the compressed weights (using `pool`), and picks the implementation that
// matches the model type.
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
// Returns a pointer to the SentencePiece tokenizer. Ownership presumably
// stays with the implementation -- TODO confirm against the definition.
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
// Model-specific implementation, set in the constructor based on model type.
std::unique_ptr<GemmaInterface> impl_;
// Training variant, taken from LoaderArgs::ModelTraining() at construction.
gcpp::ModelTraining model_training;
};
// NOTE(review): LoaderArgs appears to be fully defined earlier in this
// header, which would make this forward declaration redundant -- confirm
// before removing.
struct LoaderArgs; // forward declaration
// Declaration only; no definition is visible in this change. Presumably a
// factory-style initializer that sets up `model` from `args` -- TODO confirm
// once the definition lands.
void CreateGemma(const LoaderArgs& args, hwy::ThreadPool& pool, Gemma& model);
KVCache CreateKVCache(Model type); // convenient workaround for now
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);
// StreamFunc is called with (token, probability). For prompt tokens,
// probability is 0.0f.
using StreamFunc = std::function<bool(int, float)>;
using AcceptFunc = std::function<bool(int)>;
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
float temperature, const std::vector<int>& prompt,
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,

1
run.cc
View File

@ -236,6 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
}
gcpp::Gemma model(loader, pool);
auto kv_cache = CreateKVCache(loader.ModelType());
if (const char* error = inference.Validate()) {