Fix MSan error: uninitialized model_training

This arose during the unpacking of LoaderArgs into individual ctor args. It would probably be better to pass LoaderArgs in directly and have only a single ctor, to reduce confusion.

Also fix includes.

PiperOrigin-RevId: 617386447
This commit is contained in:
Jan Wassenberg 2024-03-20 05:12:06 +01:00
parent 6865819bb7
commit 11d9c51473
3 changed files with 9 additions and 8 deletions

View File

@ -25,6 +25,8 @@
#include "compression/compress-inl.h" #include "compression/compress-inl.h"
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "ops.h" #include "ops.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // Path
#include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/contrib/matvec/matvec-inl.h"
#include "hwy/highway.h" #include "hwy/highway.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"
@ -813,8 +815,9 @@ void GemmaImpl<ConfigGemma7B>::Generate(
} }
Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
const Path& weights_path, Model model_type, const Path& weights_path, Model model_type, ModelTraining training,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool)
: model_training(training) {
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer; std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
{ {
PROFILER_ZONE("Startup.tokenizer"); PROFILER_ZONE("Startup.tokenizer");

View File

@ -16,12 +16,9 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_H_
#include <algorithm>
#include <cctype>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <random> #include <random>
#include <string>
#include <vector> #include <vector>
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
@ -31,7 +28,7 @@
#include "configs.h" // kSeqLen #include "configs.h" // kSeqLen
// copybara:end // copybara:end
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "util/args.h" // ArgsBase #include "util/args.h" // Path
// copybara:end // copybara:end
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" // hwy::bfloat16_t #include "hwy/base.h" // hwy::bfloat16_t
@ -75,7 +72,8 @@ struct GemmaInterface;
struct Gemma { struct Gemma {
Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
const Path& weights_path, Model model_type, hwy::ThreadPool& pool); const Path& weights_path, Model model_type, ModelTraining training,
hwy::ThreadPool& pool);
~Gemma(); // must be defined after GemmaInterface's dtor is defined. ~Gemma(); // must be defined after GemmaInterface's dtor is defined.
const sentencepiece::SentencePieceProcessor* Tokenizer() const; const sentencepiece::SentencePieceProcessor* Tokenizer() const;
std::unique_ptr<GemmaInterface> impl_; std::unique_ptr<GemmaInterface> impl_;

2
run.cc
View File

@ -235,7 +235,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
} }
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights, gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights,
loader.ModelType(), pool); loader.ModelType(), loader.ModelTraining(), pool);
auto kv_cache = CreateKVCache(loader.ModelType()); auto kv_cache = CreateKVCache(loader.ModelType());