mirror of https://github.com/google/gemma.cpp.git
parent
30b8a3c1ac
commit
89be4c3de8
|
|
@ -72,4 +72,4 @@ jobs:
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/bazel
|
path: ~/.cache/bazel
|
||||||
key: bazel-${{ runner.os }}
|
key: bazel-${{ runner.os }}
|
||||||
- run: bazel build --cxxopt=-std=c++20 //...
|
- run: bazel build -c opt --cxxopt=-std=c++20 //...
|
||||||
|
|
@ -4,9 +4,7 @@
|
||||||
load("@rules_license//rules:license.bzl", "license")
|
load("@rules_license//rules:license.bzl", "license")
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_applicable_licenses = [
|
default_applicable_licenses = ["//third_party/gemma_cpp:license"],
|
||||||
"//:license", # Placeholder comment, do not modify
|
|
||||||
],
|
|
||||||
default_visibility = ["//visibility:public"],
|
default_visibility = ["//visibility:public"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
# Required for referencing bazel:com_google_sentencepiece.patch
|
|
||||||
package(
|
package(
|
||||||
default_applicable_licenses = ["//:license"],
|
default_applicable_licenses = ["//:license"],
|
||||||
default_visibility = ["//:__subpackages__"],
|
default_visibility = ["//:__subpackages__"],
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,10 @@
|
||||||
# Weight compression, I/O and analysis
|
# Weight compression, I/O and analysis
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_applicable_licenses = [
|
default_applicable_licenses = ["//third_party/gemma_cpp:license"],
|
||||||
"//:license", # Placeholder comment, do not modify
|
|
||||||
],
|
|
||||||
default_visibility = [
|
default_visibility = [
|
||||||
# Placeholder for internal visibility,
|
# Placeholder for internal visibility,
|
||||||
"//:__subpackages__", # Placeholder, do not modify
|
"//third_party/gemma_cpp:__subpackages__",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,10 +17,13 @@
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "gemma.h"
|
#include "gemma.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:end
|
||||||
#include "util/app.h" // LoaderArgs
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "util/args.h"
|
#include "util/args.h"
|
||||||
|
// copybara:end
|
||||||
|
// copybara:import_next_line:gemma_cpp
|
||||||
|
#include "util/app.h" // LoaderArgs
|
||||||
|
// copybara:end
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
std::vector<int> tokenize(
|
std::vector<int> tokenize(
|
||||||
|
|
|
||||||
15
gemma.cc
15
gemma.cc
|
|
@ -25,8 +25,6 @@
|
||||||
#include "compression/compress-inl.h"
|
#include "compression/compress-inl.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "ops.h"
|
#include "ops.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "util/args.h" // Path
|
|
||||||
#include "hwy/contrib/matvec/matvec-inl.h"
|
#include "hwy/contrib/matvec/matvec-inl.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
|
|
@ -52,7 +50,10 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
// copybara:strip_begin
|
||||||
|
// Required because sentencepiece uses Google I/O which requires InitGoogle.
|
||||||
// Placeholder for internal header, do not modify.
|
// Placeholder for internal header, do not modify.
|
||||||
|
// copybara:strip_end
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "compression/compress.h"
|
#include "compression/compress.h"
|
||||||
|
|
@ -817,9 +818,8 @@ void GemmaImpl<ConfigGemma7B>::Generate(
|
||||||
}
|
}
|
||||||
|
|
||||||
Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
||||||
const Path& weights_path, Model model_type, ModelTraining training,
|
const Path& weights_path, Model model_type,
|
||||||
hwy::ThreadPool& pool)
|
hwy::ThreadPool& pool) {
|
||||||
: model_training(training) {
|
|
||||||
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
|
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
|
||||||
{
|
{
|
||||||
PROFILER_ZONE("Startup.tokenizer");
|
PROFILER_ZONE("Startup.tokenizer");
|
||||||
|
|
@ -844,6 +844,11 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
||||||
|
Model model_type, hwy::ThreadPool& pool)
|
||||||
|
: Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type,
|
||||||
|
pool) {}
|
||||||
|
|
||||||
Gemma::~Gemma() = default; // after GemmaInterface is defined
|
Gemma::~Gemma() = default; // after GemmaInterface is defined
|
||||||
|
|
||||||
const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
|
const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
|
||||||
|
|
|
||||||
16
gemma.h
16
gemma.h
|
|
@ -16,20 +16,29 @@
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_H_
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cctype>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <random>
|
#include <random>
|
||||||
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "compression/compress.h" // SfpStream/NuqStream
|
#include "compression/compress.h" // SfpStream/NuqStream
|
||||||
|
// copybara:end
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "util/args.h" // Path
|
#include "configs.h" // kSeqLen
|
||||||
|
// copybara:end
|
||||||
|
// copybara:import_next_line:gemma_cpp
|
||||||
|
#include "util/args.h" // ArgsBase
|
||||||
|
// copybara:end
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h" // hwy::bfloat16_t
|
#include "hwy/base.h" // hwy::bfloat16_t
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
// copybara:import_next_line:sentencepiece
|
// copybara:import_next_line:sentencepiece
|
||||||
#include "src/sentencepiece_processor.h"
|
#include "src/sentencepiece_processor.h"
|
||||||
|
// copybara:end
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
@ -66,8 +75,9 @@ struct GemmaInterface;
|
||||||
|
|
||||||
struct Gemma {
|
struct Gemma {
|
||||||
Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
||||||
const Path& weights_path, Model model_type, ModelTraining training,
|
const Path& weights_path, Model model_type, hwy::ThreadPool& pool);
|
||||||
hwy::ThreadPool& pool);
|
Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
||||||
|
Model model_type, hwy::ThreadPool& pool);
|
||||||
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
||||||
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
|
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
|
||||||
std::unique_ptr<GemmaInterface> impl_;
|
std::unique_ptr<GemmaInterface> impl_;
|
||||||
|
|
|
||||||
15
ops.h
15
ops.h
|
|
@ -341,21 +341,20 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a,
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
|
static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
|
||||||
const float* HWY_RESTRICT a, size_t size) {
|
const float* HWY_RESTRICT a, size_t size) {
|
||||||
const hn::ScalableTag<float> d;
|
const hn::ScalableTag<float> d;
|
||||||
using V = hn::Vec<decltype(d)>;
|
|
||||||
const size_t N = hn::Lanes(d);
|
const size_t N = hn::Lanes(d);
|
||||||
HWY_DASSERT(size >= 2 * N);
|
HWY_DASSERT(size >= 2 * N);
|
||||||
HWY_DASSERT(size % (2 * N) == 0);
|
HWY_DASSERT(size % (2 * N) == 0);
|
||||||
|
|
||||||
V sum0 = hn::Zero(d);
|
auto sum0 = hn::Zero(d);
|
||||||
V sum1 = hn::Zero(d);
|
auto sum1 = hn::Zero(d);
|
||||||
for (size_t i = 0; i <= size - 2 * N; i += 2 * N) {
|
for (size_t i = 0; i <= size - 2 * N; i += 2 * N) {
|
||||||
const V a0 = hn::LoadU(d, a + i);
|
const auto a0 = LoadU(d, a + i);
|
||||||
sum0 = hn::MulAdd(a0, a0, sum0);
|
sum0 = MulAdd(a0, a0, sum0);
|
||||||
const V a1 = hn::LoadU(d, a + i + N);
|
const auto a1 = LoadU(d, a + i + N);
|
||||||
sum1 = hn::MulAdd(a1, a1, sum1);
|
sum1 = MulAdd(a1, a1, sum1);
|
||||||
}
|
}
|
||||||
|
|
||||||
return hn::ReduceSum(d, hn::Add(sum0, sum1));
|
return ReduceSum(d, Add(sum0, sum1));
|
||||||
}
|
}
|
||||||
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
|
|
|
||||||
12
run.cc
12
run.cc
|
|
@ -22,15 +22,19 @@
|
||||||
#include <thread> // NOLINT
|
#include <thread> // NOLINT
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// Placeholder for internal header, do not modify.
|
// Placeholder for internal header, do not modify. // copybara:strip
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "compression/compress.h"
|
#include "compression/compress.h"
|
||||||
|
// copybara:end
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "gemma.h" // Gemma
|
#include "gemma.h" // Gemma
|
||||||
|
// copybara:end
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "util/app.h"
|
#include "util/app.h"
|
||||||
|
// copybara:end
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "util/args.h" // HasHelp
|
#include "util/args.h" // HasHelp
|
||||||
|
// copybara:end
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
|
|
@ -231,8 +235,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||||
[](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
|
[](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
|
||||||
}
|
}
|
||||||
|
|
||||||
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights,
|
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
|
||||||
loader.ModelType(), loader.ModelTraining(), pool);
|
loader.ModelType(), pool);
|
||||||
|
|
||||||
auto kv_cache = CreateKVCache(loader.ModelType());
|
auto kv_cache = CreateKVCache(loader.ModelType());
|
||||||
|
|
||||||
|
|
@ -274,7 +278,9 @@ int main(int argc, char** argv) {
|
||||||
{
|
{
|
||||||
PROFILER_ZONE("Startup.misc");
|
PROFILER_ZONE("Startup.misc");
|
||||||
|
|
||||||
|
// copybara:strip_begin
|
||||||
// Placeholder for internal init, do not modify.
|
// Placeholder for internal init, do not modify.
|
||||||
|
// copybara:strip_end
|
||||||
|
|
||||||
gcpp::LoaderArgs loader(argc, argv);
|
gcpp::LoaderArgs loader(argc, argv);
|
||||||
gcpp::InferenceArgs inference(argc, argv);
|
gcpp::InferenceArgs inference(argc, argv);
|
||||||
|
|
|
||||||
|
|
@ -34,10 +34,15 @@
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "configs.h"
|
#include "configs.h"
|
||||||
|
// copybara:end
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "gemma.h"
|
#include "gemma.h"
|
||||||
|
// copybara:end
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "util/args.h"
|
#include "util/args.h"
|
||||||
|
// copybara:end
|
||||||
#include "hwy/base.h" // HWY_ASSERT
|
#include "hwy/base.h" // HWY_ASSERT
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue