mirror of https://github.com/google/gemma.cpp.git
parent
5e0cafbdc2
commit
fdc3812446
12
gemma.cc
12
gemma.cc
|
|
@ -25,8 +25,6 @@
|
|||
#include "compression/compress-inl.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "ops.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/args.h" // Path
|
||||
#include "hwy/contrib/matvec/matvec-inl.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
|
@ -815,9 +813,8 @@ void GemmaImpl<ConfigGemma7B>::Generate(
|
|||
}
|
||||
|
||||
Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
||||
const Path& weights_path, Model model_type, ModelTraining training,
|
||||
hwy::ThreadPool& pool)
|
||||
: model_training(training) {
|
||||
const Path& weights_path, Model model_type,
|
||||
hwy::ThreadPool& pool) {
|
||||
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
|
||||
{
|
||||
PROFILER_ZONE("Startup.tokenizer");
|
||||
|
|
@ -842,6 +839,11 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
|||
}
|
||||
}
|
||||
|
||||
Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
||||
Model model_type, hwy::ThreadPool& pool)
|
||||
: Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type,
|
||||
pool) {}
|
||||
|
||||
Gemma::~Gemma() = default; // after GemmaInterface is defined
|
||||
|
||||
const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
|
||||
|
|
|
|||
10
gemma.h
10
gemma.h
|
|
@ -16,9 +16,12 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
|
|
@ -28,7 +31,7 @@
|
|||
#include "configs.h" // kSeqLen
|
||||
// copybara:end
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/args.h" // Path
|
||||
#include "util/args.h" // ArgsBase
|
||||
// copybara:end
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // hwy::bfloat16_t
|
||||
|
|
@ -72,8 +75,9 @@ struct GemmaInterface;
|
|||
|
||||
struct Gemma {
|
||||
Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
||||
const Path& weights_path, Model model_type, ModelTraining training,
|
||||
hwy::ThreadPool& pool);
|
||||
const Path& weights_path, Model model_type, hwy::ThreadPool& pool);
|
||||
Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
||||
Model model_type, hwy::ThreadPool& pool);
|
||||
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
||||
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
|
||||
std::unique_ptr<GemmaInterface> impl_;
|
||||
|
|
|
|||
4
run.cc
4
run.cc
|
|
@ -234,8 +234,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
[](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
|
||||
}
|
||||
|
||||
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights,
|
||||
loader.ModelType(), loader.ModelTraining(), pool);
|
||||
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
|
||||
loader.ModelType(), pool);
|
||||
|
||||
auto kv_cache = CreateKVCache(loader.ModelType());
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue