mirror of https://github.com/google/gemma.cpp.git
Fix msan error, uninitialized model_training
This arose during the unpacking of LoaderArgs into individual ctor args. Probably better to pass LoaderArgs in, and have only a single ctor to reduce confusion. Also fix includes. PiperOrigin-RevId: 617386447
This commit is contained in:
parent
e2a04b79ed
commit
edaafe335f
7
gemma.cc
7
gemma.cc
|
|
@ -25,6 +25,8 @@
|
||||||
#include "compression/compress-inl.h"
|
#include "compression/compress-inl.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "ops.h"
|
#include "ops.h"
|
||||||
|
// copybara:import_next_line:gemma_cpp
|
||||||
|
#include "util/args.h" // Path
|
||||||
#include "hwy/contrib/matvec/matvec-inl.h"
|
#include "hwy/contrib/matvec/matvec-inl.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
|
|
@ -813,8 +815,9 @@ void GemmaImpl<ConfigGemma7B>::Generate(
|
||||||
}
|
}
|
||||||
|
|
||||||
Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
||||||
const Path& weights_path, Model model_type,
|
const Path& weights_path, Model model_type, ModelTraining training,
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool)
|
||||||
|
: model_training(training) {
|
||||||
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
|
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
|
||||||
{
|
{
|
||||||
PROFILER_ZONE("Startup.tokenizer");
|
PROFILER_ZONE("Startup.tokenizer");
|
||||||
|
|
|
||||||
8
gemma.h
8
gemma.h
|
|
@ -16,12 +16,9 @@
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_H_
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_H_
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <cctype>
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
|
|
@ -31,7 +28,7 @@
|
||||||
#include "configs.h" // kSeqLen
|
#include "configs.h" // kSeqLen
|
||||||
// copybara:end
|
// copybara:end
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "util/args.h" // ArgsBase
|
#include "util/args.h" // Path
|
||||||
// copybara:end
|
// copybara:end
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h" // hwy::bfloat16_t
|
#include "hwy/base.h" // hwy::bfloat16_t
|
||||||
|
|
@ -75,7 +72,8 @@ struct GemmaInterface;
|
||||||
|
|
||||||
struct Gemma {
|
struct Gemma {
|
||||||
Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
||||||
const Path& weights_path, Model model_type, hwy::ThreadPool& pool);
|
const Path& weights_path, Model model_type, ModelTraining training,
|
||||||
|
hwy::ThreadPool& pool);
|
||||||
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
||||||
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
|
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
|
||||||
std::unique_ptr<GemmaInterface> impl_;
|
std::unique_ptr<GemmaInterface> impl_;
|
||||||
|
|
|
||||||
2
run.cc
2
run.cc
|
|
@ -235,7 +235,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||||
}
|
}
|
||||||
|
|
||||||
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights,
|
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights,
|
||||||
loader.ModelType(), pool);
|
loader.ModelType(), loader.ModelTraining(), pool);
|
||||||
|
|
||||||
auto kv_cache = CreateKVCache(loader.ModelType());
|
auto kv_cache = CreateKVCache(loader.ModelType());
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue