Merge branch 'dev' into improve_ops_utility

This commit is contained in:
enum-class 2024-03-23 10:36:56 +08:00
commit d079c8f1ba
23 changed files with 362 additions and 187 deletions

View File

@ -72,4 +72,4 @@ jobs:
with: with:
path: ~/.cache/bazel path: ~/.cache/bazel
key: bazel-${{ runner.os }} key: bazel-${{ runner.os }}
- run: bazel build -c opt --cxxopt=-std=c++20 //... - run: bazel build --cxxopt=-std=c++20 //...

View File

@ -4,7 +4,9 @@
load("@rules_license//rules:license.bzl", "license") load("@rules_license//rules:license.bzl", "license")
package( package(
default_applicable_licenses = ["//:license"], default_applicable_licenses = [
"//:license", # Placeholder comment, do not modify
],
default_visibility = ["//visibility:public"], default_visibility = ["//visibility:public"],
) )

View File

@ -40,6 +40,7 @@ set(SOURCES
compression/nuq-inl.h compression/nuq-inl.h
compression/sfp.h compression/sfp.h
compression/sfp-inl.h compression/sfp-inl.h
compression/test_util.h
util/app.h util/app.h
util/args.h util/args.h
) )

View File

@ -118,8 +118,7 @@ jax / pytorch / keras for NN deployments.
### Gemma struct contains all the state of the inference engine - tokenizer, weights, and activations ### Gemma struct contains all the state of the inference engine - tokenizer, weights, and activations
`Gemma(...)` - constructor, creates a gemma model object, which is a wrapper `Gemma(...)` - constructor, creates a gemma model object.
around 3 things - the tokenizer object, weights, activations, and KV Cache.
In a standard LLM chat app, you'll probably use a Gemma object directly, in In a standard LLM chat app, you'll probably use a Gemma object directly, in
more exotic data processing or research applications, you might decompose more exotic data processing or research applications, you might decompose
@ -129,11 +128,13 @@ only using a Gemma object.
### Use the tokenizer in the Gemma object (or interact with the Tokenizer object directly) ### Use the tokenizer in the Gemma object (or interact with the Tokenizer object directly)
You pretty much only do things with the tokenizer, call `Encode()` to go from The Gemma object contains contains a pointer to a Tokenizer object. The main
string prompts to token id vectors, or `Decode()` to go from token id vector operations performed on the tokenizer are to load the tokenizer model from a
outputs from the model back to strings. file (usually `tokenizer.spm`), call `Encode()` to go from string prompts to
token id vectors, or `Decode()` to go from token id vector outputs from the
model back to strings.
### The main entrypoint for generation is `GenerateGemma()` ### `GenerateGemma()` is the entrypoint for token generation
Calling into `GenerateGemma` with a tokenized prompt will 1) mutate the Calling into `GenerateGemma` with a tokenized prompt will 1) mutate the
activation values in `model` and 2) invoke StreamFunc - a lambda callback for activation values in `model` and 2) invoke StreamFunc - a lambda callback for
@ -150,7 +151,7 @@ constrained decoding type of use cases where you want to force the generation
to fit a grammar. If you're not doing this, you can send an empty lambda as a to fit a grammar. If you're not doing this, you can send an empty lambda as a
no-op which is what `run.cc` does. no-op which is what `run.cc` does.
### If you want to invoke the neural network forward function directly call the `Transformer()` function ### `Transformer()` implements the inference (i.e. `forward()` method in PyTorch or Jax) computation of the neural network
For high-level applications, you might only call `GenerateGemma()` and never For high-level applications, you might only call `GenerateGemma()` and never
interact directly with the neural network, but if you're doing something a bit interact directly with the neural network, but if you're doing something a bit

View File

@ -36,7 +36,12 @@ For production-oriented edge deployments we recommend standard deployment
pathways using Python frameworks like JAX, Keras, PyTorch, and Transformers pathways using Python frameworks like JAX, Keras, PyTorch, and Transformers
([all model variations here](https://www.kaggle.com/models/google/gemma)). ([all model variations here](https://www.kaggle.com/models/google/gemma)).
Community contributions large and small are welcome. This project follows ## Contributing
Community contributions large and small are welcome. See
[DEVELOPERS.md](https://github.com/google/gemma.cpp/blob/main/DEVELOPERS.md)
for additional notes contributing developers and [join the discord by following
this invite link](https://discord.gg/H5jCBAWxAe). This project follows
[Google's Open Source Community [Google's Open Source Community
Guidelines](https://opensource.google.com/conduct/). Guidelines](https://opensource.google.com/conduct/).

View File

@ -1,3 +1,4 @@
# Required for referencing bazel:com_google_sentencepiece.patch
package( package(
default_applicable_licenses = ["//:license"], default_applicable_licenses = ["//:license"],
default_visibility = ["//:__subpackages__"], default_visibility = ["//:__subpackages__"],

View File

@ -1,10 +1,12 @@
# Weight compression, I/O and analysis # Weight compression, I/O and analysis
package( package(
default_applicable_licenses = ["//:license"], default_applicable_licenses = [
"//:license", # Placeholder comment, do not modify
],
default_visibility = [ default_visibility = [
"//learning/gemini/prod/contrib/gemini_cpp:__subpackages__", # Placeholder for internal visibility,
"//:__subpackages__", "//:__subpackages__", # Placeholder, do not modify
], ],
) )
@ -22,20 +24,36 @@ cc_library(
], ],
) )
# Deprecated because it is also implemented in Highway; will be removed once
# that Highway version is sufficiently widespread.
cc_library( cc_library(
name = "stats", name = "stats",
srcs = [ srcs = ["stats.cc"],
"stats.cc", hdrs = ["stats.h"],
],
hdrs = [
"distortion.h",
"stats.h",
],
deps = [ deps = [
"@hwy//:hwy", "@hwy//:hwy",
], ],
) )
cc_library(
name = "distortion",
hdrs = ["distortion.h"],
deps = [
"@hwy//:hwy",
],
)
cc_library(
name = "test_util",
hdrs = ["test_util.h"],
deps = [
":distortion",
":stats",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
],
)
cc_library( cc_library(
name = "sfp", name = "sfp",
hdrs = [ hdrs = [
@ -60,12 +78,11 @@ cc_test(
tags = ["hwy_ops_test"], tags = ["hwy_ops_test"],
deps = [ deps = [
":sfp", ":sfp",
":stats", ":test_util",
"@googletest//:gtest_main", "@googletest//:gtest_main",
"@hwy//:hwy", "@hwy//:hwy",
"@hwy//:hwy_test_util", "@hwy//:hwy_test_util",
"@hwy//:nanobenchmark", "@hwy//:nanobenchmark",
"@hwy//:thread_pool",
], ],
) )
@ -96,7 +113,7 @@ cc_test(
deps = [ deps = [
":nuq", ":nuq",
":sfp", ":sfp",
":stats", ":test_util",
"@googletest//:gtest_main", "@googletest//:gtest_main",
"@hwy//:hwy", "@hwy//:hwy",
"@hwy//:hwy_test_util", "@hwy//:hwy_test_util",
@ -116,6 +133,7 @@ cc_library(
], ],
deps = [ deps = [
":blob_store", ":blob_store",
":distortion",
":nuq", ":nuq",
":sfp", ":sfp",
":stats", ":stats",
@ -132,6 +150,7 @@ cc_library(
"analyze.h", "analyze.h",
], ],
deps = [ deps = [
":distortion",
":nuq", ":nuq",
":sfp", ":sfp",
":stats", ":stats",

View File

@ -42,6 +42,10 @@
#include "hwy/contrib/sort/vqsort-inl.h" #include "hwy/contrib/sort/vqsort-inl.h"
#include "hwy/highway.h" #include "hwy/highway.h"
#ifndef HWY_IF_CONSTEXPR
#define HWY_IF_CONSTEXPR if
#endif
HWY_BEFORE_NAMESPACE(); HWY_BEFORE_NAMESPACE();
namespace gcpp { namespace gcpp {
namespace HWY_NAMESPACE { namespace HWY_NAMESPACE {
@ -124,7 +128,7 @@ class NuqClustering {
} }
private: private:
// Float has enough precision for our relatively small kGroupSize (128). // Float has enough precision for our relatively small kGroupSize (256).
float cumsum_[kGroupSize + 1]; float cumsum_[kGroupSize + 1];
float cumsum2_[kGroupSize + 1]; float cumsum2_[kGroupSize + 1];
float inv_len_[kGroupSize + 1]; float inv_len_[kGroupSize + 1];
@ -168,8 +172,8 @@ class NuqClustering {
// `centers`; prior centers are zero-initialized. // `centers`; prior centers are zero-initialized.
// //
// O(kClusters * kGroupSize * kGroupSize), but the constant factors are so low // O(kClusters * kGroupSize * kGroupSize), but the constant factors are so low
// that this is about 10 times as fast as the O(kClusters * kGroupSize) SMAWK // that this is about 5 times as fast as the O(kClusters * kGroupSize) SMAWK
// as implemented in FAISS, for our kGroupSize <= 128. // as implemented in FAISS, for our kGroupSize of 256.
template <class DF> template <class DF>
static HWY_NOINLINE size_t ClusterExactL2(DF df, const float* x, static HWY_NOINLINE size_t ClusterExactL2(DF df, const float* x,
ClusterBuf& buf, ClusterBuf& buf,
@ -228,7 +232,7 @@ class NuqClustering {
// Center = mean, O(1) thanks to cumulative sums. // Center = mean, O(1) thanks to cumulative sums.
const float sum = cc.SumOfSorted(start, last); const float sum = cc.SumOfSorted(start, last);
const int size = static_cast<int>(last) - static_cast<int>(start) + 1; const int size = static_cast<int>(last) - static_cast<int>(start) + 1;
HWY_DASSERT(0 < size && size <= kGroupSize); HWY_DASSERT(0 < size && size <= static_cast<int>(kGroupSize));
centers[k] = sum / static_cast<float>(size); centers[k] = sum / static_cast<float>(size);
// We know the range inside sorted_and_i[]; translate to original indices, // We know the range inside sorted_and_i[]; translate to original indices,
@ -427,7 +431,7 @@ class NuqCodec {
// instead of TableLookupBytes, which requires extra interleaving of lo/hi. // instead of TableLookupBytes, which requires extra interleaving of lo/hi.
HWY_DASSERT(hn::Lanes(du) >= 8); HWY_DASSERT(hn::Lanes(du) >= 8);
if (NumTables(du) == 2) { HWY_IF_CONSTEXPR(NumTables(du) == 2) {
// Reduce cap for second half to avoid loading past the end of the table. // Reduce cap for second half to avoid loading past the end of the table.
const hn::CappedTag<hwy::bfloat16_t, kClusters / 2> d_table2; const hn::CappedTag<hwy::bfloat16_t, kClusters / 2> d_table2;
*tbl1 = hn::ResizeBitCast(du, hn::LoadU(d_table2, table + kClusters / 2)); *tbl1 = hn::ResizeBitCast(du, hn::LoadU(d_table2, table + kClusters / 2));
@ -449,11 +453,12 @@ class NuqCodec {
const auto indices0 = hn::IndicesFromVec(du, idx0); const auto indices0 = hn::IndicesFromVec(du, idx0);
const auto indices1 = hn::IndicesFromVec(du, idx1); const auto indices1 = hn::IndicesFromVec(du, idx1);
if (NumTables(du) == 1) { HWY_IF_CONSTEXPR(NumTables(du) == 1) {
(void)tbl1; (void)tbl1;
c0 = hn::TableLookupLanes(tbl0, indices0); c0 = hn::TableLookupLanes(tbl0, indices0);
c1 = hn::TableLookupLanes(tbl0, indices1); c1 = hn::TableLookupLanes(tbl0, indices1);
} else { }
HWY_IF_CONSTEXPR(NumTables(du) == 2) { // `else` is poorly formatted.
c0 = hn::TwoTablesLookupLanes(du, tbl0, tbl1, indices0); c0 = hn::TwoTablesLookupLanes(du, tbl0, tbl1, indices0);
c1 = hn::TwoTablesLookupLanes(du, tbl0, tbl1, indices1); c1 = hn::TwoTablesLookupLanes(du, tbl0, tbl1, indices1);
} }
@ -521,8 +526,8 @@ class NuqCodec {
// Decodes `num` values from the stream `in`, starting at the offset `in_ofs` // Decodes `num` values from the stream `in`, starting at the offset `in_ofs`
// (in units of values), to bf16 in `out`. `in_capacity`, `in_ofs` and `num` // (in units of values), to bf16 in `out`. `in_capacity`, `in_ofs` and `num`
// must all be multiples of `kGroupSize`. // must all be multiples of `kGroupSize`.
template <class DF, HWY_IF_BF16_D(DF)> template <class DBF, HWY_IF_BF16_D(DBF)>
static HWY_INLINE void Dec(DF dbf, const size_t in_capacity, static HWY_INLINE void Dec(DBF dbf, const size_t in_capacity,
const NuqStream* const in, const size_t in_ofs, const NuqStream* const in, const size_t in_ofs,
hwy::bfloat16_t* const out, const size_t num) { hwy::bfloat16_t* const out, const size_t num) {
const hn::RebindToUnsigned<decltype(dbf)> d16; const hn::RebindToUnsigned<decltype(dbf)> d16;

View File

@ -27,6 +27,7 @@
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/timer.h"
// clang-format off // clang-format off
#undef HWY_TARGET_INCLUDE #undef HWY_TARGET_INCLUDE
@ -35,15 +36,14 @@
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
// Other headers that include Highway must come after foreach_target.h // Other headers that include Highway must come after foreach_target.h
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "compression/distortion.h"
// copybara:import_next_line:gemma_cpp
#include "compression/nuq-inl.h" #include "compression/nuq-inl.h"
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "compression/nuq.h" #include "compression/nuq.h"
// copybara:import_next_line:gemma_cpp
#include "compression/test_util.h"
#include "hwy/highway.h" #include "hwy/highway.h"
#include "hwy/tests/hwy_gtest.h" #include "hwy/tests/hwy_gtest.h"
#include "hwy/tests/test_util-inl.h" #include "hwy/tests/test_util-inl.h"
#include "hwy/timer.h"
HWY_BEFORE_NAMESPACE(); HWY_BEFORE_NAMESPACE();
namespace gcpp { namespace gcpp {
@ -181,12 +181,14 @@ struct TestNormal {
auto in = hwy::AllocateAligned<float>(kGroupSize); auto in = hwy::AllocateAligned<float>(kGroupSize);
HWY_ASSERT(in); HWY_ASSERT(in);
std::mt19937 rng(123); hwy::RandomState rng;
std::normal_distribution<float> dist{0.001f, 0.3f}; Stats in_stats;
for (size_t i = 0; i < kGroupSize; ++i) { for (size_t i = 0; i < kGroupSize; ++i) {
in[i] = dist(rng); const double r = RandomGaussian(rng);
in_stats.Notify(r);
in[i] = hwy::ConvertScalarTo<T>(r);
} }
std::shuffle(in.get(), in.get() + kGroupSize, rng); VerifyGaussian(in_stats);
ClusterBuf buf; ClusterBuf buf;
float centers[kClusters]; float centers[kClusters];
@ -212,9 +214,9 @@ struct TestNormal {
const float snr = stats.GeomeanValueDivL1(); const float snr = stats.GeomeanValueDivL1();
fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr, fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
stats.MaxIndex(), stats.MaxL1()); stats.MaxIndex(), stats.MaxL1());
static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected"); static_assert(kGroupSize == 256, "Update expected");
const float expected_pnorm = kGroupSize == 128 ? 3E-2f : 3.4E-2f; const float expected_pnorm = 3.68E-2f;
const float expected_snr = kGroupSize == 128 ? 17.4f : 13.1f; const float expected_snr = 12.7f;
HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm); HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm);
HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr); HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
} }
@ -345,21 +347,27 @@ struct TestDot {
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(num)); auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(num));
HWY_ASSERT(in && dec && vec && nuq); HWY_ASSERT(in && dec && vec && nuq);
std::mt19937 rng(123); // Generate inputs and verify their distribution.
std::normal_distribution<float> dist{0.001f, 0.3f}; hwy::RandomState rng;
Stats in_stats;
for (size_t i = 0; i < num; ++i) { for (size_t i = 0; i < num; ++i) {
in[i] = dist(rng); const float r = static_cast<float>(RandomGaussian(rng));
vec[i] = hwy::ConvertScalarTo<T>(dist(rng)); in_stats.Notify(r);
in[i] = r;
} }
// This changes the correlation between in and vec, which considerably for (size_t i = 0; i < num; ++i) {
// affects the error of the result. const float r = static_cast<float>(RandomGaussian(rng));
std::shuffle(in.get(), in.get() + num, rng); in_stats.Notify(r);
vec[i] = hwy::ConvertScalarTo<T>(r);
}
VerifyGaussian(in_stats);
ClusterBuf buf; ClusterBuf buf;
const size_t unused_clusters = const size_t unused_clusters =
NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0); NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0);
HWY_ASSERT(unused_clusters == 0); HWY_ASSERT(unused_clusters == 0);
// Compute dot product without decompression.
double actual = 0.0; double actual = 0.0;
double elapsed = hwy::HighestValue<double>(); double elapsed = hwy::HighestValue<double>();
for (size_t rep = 0; rep < 20; ++rep) { for (size_t rep = 0; rep < 20; ++rep) {
@ -380,24 +388,39 @@ struct TestDot {
fprintf(stderr, "Vec %zu Dec %.2f MB/s\n", Lanes(d) * sizeof(T), fprintf(stderr, "Vec %zu Dec %.2f MB/s\n", Lanes(d) * sizeof(T),
num * sizeof(in[0]) * 1E-6 / elapsed); num * sizeof(in[0]) * 1E-6 / elapsed);
double expected = 0.0; // using original input // Exact and decompressed dot products for comparison.
double expected2 = 0.0; // using decoded NUQ double exact = 0.0; // using original input
double expected = 0.0; // using decoded NUQ
DistortionStats dec_stats;
Stats ratios;
for (size_t i = 0; i < num; ++i) { for (size_t i = 0; i < num; ++i) {
expected += in[i] * hwy::ConvertScalarTo<double>(vec[i]); dec_stats.Notify(in[i], dec[i]);
expected2 += dec[i] * hwy::ConvertScalarTo<double>(vec[i]); const float v1 = hwy::ConvertScalarTo<float>(vec[i]);
exact += in[i] * v1;
expected += dec[i] * v1;
if (expected != 0.0f) {
ratios.Notify(exact / expected);
}
} }
const double l1 = hwy::ScalarAbs(expected - actual); const double dec_snr = dec_stats.GeomeanValueDivL1();
const double snr = 1.0 + hwy::ScalarAbs(expected) / l1; const double dot_snr = 1.0 / hwy::ScalarAbs(1.0 - ratios.GeometricMean());
fprintf(stderr, "expected %.3f e2 %.4f actual %.4f l1 %E snr %.2f\n", // exact and actual fluctuate due to the combination of NUQ imprecision,
expected, expected2, actual, l1, snr); // and whether vec[i] is negative or positive, so this is quite loose.
HWY_ASSERT(hwy::ScalarAbs(expected2 - actual) < 1E-4); const float final_ratio = HWY_MIN(exact / actual, actual / exact);
static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected"); fprintf(stderr, "ratios %s\n", ratios.ToString().c_str());
const double expected_l1 = kGroupSize == 128 ? 7.3E-2 : 4.34E-2; fprintf(stderr,
const double expected_snr = kGroupSize == 128 ? 9.7f "exact %.3f e2 %.4f actual %.4f final_ratio %.3f dec_snr %.2f "
: sizeof(T) == 2 ? 14.5f "dot_snr %.2f\n",
: 14.9f; exact, expected, actual, final_ratio, dec_snr, dot_snr);
HWY_ASSERT(expected_l1 <= l1 && l1 < 1.02f * expected_l1); // Final values are not too far apart.
HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr); HWY_ASSERT(0.88f <= final_ratio && final_ratio <= 1.0f);
// Decompressed and uncompressed dot should match exactly.
HWY_ASSERT(hwy::ScalarAbs(expected - actual) < 1E-4f);
// dec[] is close to in[], but we already check that in TestStream.
HWY_ASSERT(dec_snr >= 13.0);
// Geomean of ratios for each i is an approximation of the actual SNR.
HWY_ASSERT(dot_snr >= (sizeof(T) == 2 ? 17.0 : 14.0));
static_assert(kGroupSize == 256, "Update expected*");
} }
}; };
@ -429,6 +452,9 @@ HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamF32);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamBF16); HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamBF16);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllDotF32); HWY_EXPORT_AND_TEST_P(NuqTest, TestAllDotF32);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllDotBF16); HWY_EXPORT_AND_TEST_P(NuqTest, TestAllDotBF16);
#ifdef HWY_AFTER_TEST
HWY_AFTER_TEST();
#endif
} // namespace gcpp } // namespace gcpp
#endif #endif

View File

@ -449,6 +449,18 @@ class SfpCodec {
return Enc2U(d16, w0, w1); return Enc2U(d16, w0, w1);
} }
// Truncates two f32 to bf16, in lane order, without rounding (see Enc4F).
template <class DBF, class DF = hn::RepartitionToWide<DBF>>
static HWY_INLINE hn::Vec<DBF> Truncate2To(DBF dbf, hn::Vec<DF> f0,
hn::Vec<DF> f1) {
const hn::RebindToUnsigned<DBF> d16;
using V16 = hn::Vec<decltype(d16)>;
const V16 u0 = BitCast(d16, f0);
const V16 u1 = BitCast(d16, f1);
return BitCast(DBF(), HWY_IS_LITTLE_ENDIAN ? ConcatOdd(d16, u1, u0)
: ConcatEven(d16, u1, u0));
}
template <class DF, HWY_IF_F32_D(DF), template <class DF, HWY_IF_F32_D(DF),
class V8 = hn::Vec<hn::Repartition<uint8_t, DF>>> class V8 = hn::Vec<hn::Repartition<uint8_t, DF>>>
static HWY_INLINE V8 Enc4F(DF df, const float* HWY_RESTRICT in) { static HWY_INLINE V8 Enc4F(DF df, const float* HWY_RESTRICT in) {
@ -462,9 +474,10 @@ class SfpCodec {
const VF f1 = hn::LoadU(df, in + NF * 1); const VF f1 = hn::LoadU(df, in + NF * 1);
const VF f2 = hn::LoadU(df, in + NF * 2); const VF f2 = hn::LoadU(df, in + NF * 2);
const VF f3 = hn::LoadU(df, in + NF * 3); const VF f3 = hn::LoadU(df, in + NF * 3);
// Chop off the lower 16 bits; EncBytes still rounds properly. // Chop off the lower 16 bits instead of OrderedDemote2To, which rounds to
const V16 w0 = hn::BitCast(d16, hn::OrderedDemote2To(dbf, f0, f1)); // the nearest bf16, because EncBytes will round again.
const V16 w1 = hn::BitCast(d16, hn::OrderedDemote2To(dbf, f2, f3)); const V16 w0 = hn::BitCast(d16, Truncate2To(dbf, f0, f1));
const V16 w1 = hn::BitCast(d16, Truncate2To(dbf, f2, f3));
return Enc2U(d16, w0, w1); return Enc2U(d16, w0, w1);
} }

View File

@ -25,12 +25,11 @@
#include <stdint.h> #include <stdint.h>
#include <stdio.h> #include <stdio.h>
#include <algorithm>
#include <random>
#include <set> #include <set>
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/timer.h"
// clang-format off // clang-format off
#undef HWY_TARGET_INCLUDE #undef HWY_TARGET_INCLUDE
@ -39,13 +38,12 @@
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
// Any highway.h must come after foreach_target.h // Any highway.h must come after foreach_target.h
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "compression/distortion.h"
// copybara:import_next_line:gemma_cpp
#include "compression/sfp-inl.h" #include "compression/sfp-inl.h"
// copybara:import_next_line:gemma_cpp
#include "compression/test_util.h"
#include "hwy/highway.h" #include "hwy/highway.h"
#include "hwy/tests/hwy_gtest.h" #include "hwy/tests/hwy_gtest.h"
#include "hwy/tests/test_util-inl.h" #include "hwy/tests/test_util-inl.h"
#include "hwy/timer.h"
HWY_BEFORE_NAMESPACE(); HWY_BEFORE_NAMESPACE();
namespace gcpp { namespace gcpp {
@ -358,25 +356,31 @@ struct TestDot {
template <typename T, class D> template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) { HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::Repartition<float, D> df; const hn::Repartition<float, D> df;
const size_t num = 384; const size_t num = 1024; // not too many for GeometricMean overflow.
auto in = hwy::AllocateAligned<T>(num); auto in = hwy::AllocateAligned<T>(num);
auto dec = hwy::AllocateAligned<T>(num); auto dec = hwy::AllocateAligned<T>(num);
auto vec = hwy::AllocateAligned<T>(num); auto vec = hwy::AllocateAligned<T>(num);
auto sfp = hwy::AllocateAligned<SfpStream>(num); auto sfp = hwy::AllocateAligned<SfpStream>(num);
HWY_ASSERT(in && dec && vec && sfp); HWY_ASSERT(in && dec && vec && sfp);
std::mt19937 rng(123); // Generate inputs and verify their distribution.
std::normal_distribution<float> dist{0.001f, 0.3f}; hwy::RandomState rng;
Stats in_stats;
for (size_t i = 0; i < num; ++i) { for (size_t i = 0; i < num; ++i) {
in[i] = hwy::ConvertScalarTo<T>(dist(rng)); const float r = static_cast<float>(RandomGaussian(rng));
vec[i] = hwy::ConvertScalarTo<T>(dist(rng)); in_stats.Notify(r);
in[i] = hwy::ConvertScalarTo<T>(r);
} }
// This changes the correlation between in and vec, which considerably for (size_t i = 0; i < num; ++i) {
// affects the error of the result. const float r = static_cast<float>(RandomGaussian(rng));
std::shuffle(in.get(), in.get() + num, rng); in_stats.Notify(r);
vec[i] = hwy::ConvertScalarTo<T>(r);
}
VerifyGaussian(in_stats);
SfpCodec::Enc(d, in.get(), num, sfp.get()); SfpCodec::Enc(d, in.get(), num, sfp.get());
// Compute dot product without decompression.
double actual = 0.0; double actual = 0.0;
double elapsed = hwy::HighestValue<double>(); double elapsed = hwy::HighestValue<double>();
for (size_t rep = 0; rep < 200; ++rep) { for (size_t rep = 0; rep < 200; ++rep) {
@ -393,26 +397,44 @@ struct TestDot {
} }
SfpCodec::Dec(d, sfp.get(), num, dec.get()); SfpCodec::Dec(d, sfp.get(), num, dec.get());
fprintf(stderr, "Vec %zu Dot %.2f MB/s\n", Lanes(d) * sizeof(T), fprintf(stderr, "Vec %zu Dot %zu-bit %.2f MB/s\n", Lanes(d) * sizeof(T),
num * sizeof(T) * 1E-6 / elapsed); sizeof(T) * 8, num * sizeof(T) * 1E-6 / elapsed);
double expected = 0.0; // using original input // Exact and decompressed dot products for comparison.
double expected2 = 0.0; // using decoded SFP float exact = 0.0f; // using original input
float expected = 0.0f; // using decoded SFP
DistortionStats dec_stats;
Stats ratios;
for (size_t i = 0; i < num; ++i) { for (size_t i = 0; i < num; ++i) {
expected += hwy::ConvertScalarTo<double>(in[i]) * const float in1 = hwy::ConvertScalarTo<float>(in[i]);
hwy::ConvertScalarTo<double>(vec[i]); const float dec1 = hwy::ConvertScalarTo<float>(dec[i]);
expected2 += hwy::ConvertScalarTo<double>(dec[i]) * const float vec1 = hwy::ConvertScalarTo<float>(vec[i]);
hwy::ConvertScalarTo<double>(vec[i]); dec_stats.Notify(in1, dec1);
exact += in1 * vec1;
expected += dec1 * vec1;
if (expected != 0.0f) {
ratios.Notify(exact / expected);
}
} }
const double l1 = hwy::ScalarAbs(expected - actual); const double dec_snr = dec_stats.GeomeanValueDivL1();
const double snr = 1.0 + hwy::ScalarAbs(expected) / l1; const double dot_snr = 1.0 / hwy::ScalarAbs(1.0 - ratios.GeometricMean());
fprintf(stderr, "expected %.3f e2 %.4f actual %.4f l1 %E snr %.2f\n", // exact and actual fluctuate due to the combination of SFP imprecision,
expected, expected2, actual, l1, snr); // and whether vec[i] is negative or positive, so this is quite loose.
HWY_ASSERT(hwy::ScalarAbs(expected2 - actual) < 1E-4); const float final_ratio = HWY_MIN(exact / actual, actual / exact);
const double expected_l1 = sizeof(T) == 2 ? 1.52E-2 : 1.15E-2; fprintf(stderr, "ratios %s\n", ratios.ToString().c_str());
const double expected_snr = sizeof(T) == 2 ? 80.1f : 104.9f; fprintf(stderr,
HWY_ASSERT(expected_l1 <= l1 && l1 < 1.02f * expected_l1); "exact %.3f e2 %.4f actual %.4f final_ratio %.3f dec_snr %.2f "
HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr); "dot_snr %.2f\n",
exact, expected, actual, final_ratio, dec_snr, dot_snr);
// Final values are not too far apart.
HWY_ASSERT(0.87f <= final_ratio && final_ratio <= 1.0f);
// Decompressed and uncompressed dot should match exactly.
HWY_ASSERT(hwy::ScalarAbs(expected - actual) < 1E-4f);
// dec[] is close to in[], but we already check that in TestEncDec.
HWY_ASSERT(dec_snr >= 50.0);
// Geomean of ratios for each i should be very close to one.
HWY_ASSERT(dot_snr >= (sizeof(T) == 2 ? 70.0 : 1000.0));
} }
}; };
@ -441,6 +463,9 @@ HWY_EXPORT_AND_TEST_P(SfpTest, TestAllEncDec);
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllOrder); HWY_EXPORT_AND_TEST_P(SfpTest, TestAllOrder);
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDotF32); HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDotF32);
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDotBF16); HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDotBF16);
#ifdef HWY_AFTER_TEST
HWY_AFTER_TEST();
#endif
} // namespace gcpp } // namespace gcpp
#endif #endif

64
compression/test_util.h Normal file
View File

@ -0,0 +1,64 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_H_
#include <stddef.h>
#include <stdint.h>
#include <cmath>
#include "hwy/base.h"
// IWYU pragma: begin_exports
// copybara:import_next_line:gemma_cpp
#include "compression/distortion.h"
// copybara:import_next_line:gemma_cpp
#include "compression/stats.h"
#include "hwy/tests/test_util.h" // RandomState
// IWYU pragma: end_exports
namespace gcpp {
// Returns random Gaussian (mean=0, stddev=1/3 similar to expected weights)
// using the central limit theorem. Avoid std::normal_distribution for
// consistent cross-platform output.
HWY_INLINE double RandomGaussian(hwy::RandomState& rng) {
uint64_t sum = 0;
constexpr int kReps = 40;
for (int rep = 0; rep < kReps; ++rep) {
sum += hwy::Random32(&rng) & 0xFFFFF;
}
const double sum_f =
static_cast<double>(sum) / static_cast<double>(0xFFFFF * kReps);
HWY_ASSERT(0.0 <= sum_f && sum_f <= 1.0);
const double plus_minus_1 = 2.0 * sum_f - 1.0;
HWY_ASSERT(-1.0 <= plus_minus_1 && plus_minus_1 <= 1.0);
// Normalize by stddev of sum of uniform random scaled to [-1, 1].
return plus_minus_1 * std::sqrt(kReps / 3.0);
};
HWY_INLINE void VerifyGaussian(Stats& stats) {
const double stddev = stats.StandardDeviation();
HWY_ASSERT(-0.01 <= stats.Mean() && stats.Mean() <= 0.01);
HWY_ASSERT(0.30 <= stddev && stddev <= 0.35);
HWY_ASSERT(-1.1 <= stats.Min() && stats.Min() <= -0.9);
HWY_ASSERT(0.9 <= stats.Max() && stats.Max() <= 1.1);
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_H_

View File

@ -37,7 +37,7 @@ static constexpr size_t kTopK = GEMMA_TOPK;
struct ConfigGemma7B { struct ConfigGemma7B {
static constexpr int kSeqLen = gcpp::kSeqLen; static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = 256128; static constexpr int kVocabSize = 256000;
static constexpr int kLayers = 28; static constexpr int kLayers = 28;
static constexpr int kModelDim = 3072; static constexpr int kModelDim = 3072;
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576 static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
@ -49,12 +49,12 @@ struct ConfigGemma7B {
struct ConfigGemma2B { struct ConfigGemma2B {
static constexpr int kSeqLen = gcpp::kSeqLen; static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = 256128; static constexpr int kVocabSize = 256000;
static constexpr int kLayers = 18; static constexpr int kLayers = 18;
static constexpr int kModelDim = 2048; static constexpr int kModelDim = 2048;
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
static constexpr int kHeads = 8; static constexpr int kHeads = 8;
static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK; static constexpr int kTopK = gcpp::kTopK;
}; };

7
examples/README.md Normal file
View File

@ -0,0 +1,7 @@
# Examples
In this directory are some simple examples illustrating usage of `gemma.cpp` as
a library beyond the interactive `gemma` app implemented in `run.cc`.
- `hello_world/` - minimal/template project for using `gemma.cpp` as a library.
It sets up the model state and generates text for a single hard coded prompt.

View File

@ -17,13 +17,10 @@
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "gemma.h" #include "gemma.h"
// copybara:end // copybara:import_next_line:gemma_cpp
#include "util/app.h" // LoaderArgs
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "util/args.h" #include "util/args.h"
// copybara:end
// copybara:import_next_line:gemma_cpp
#include "util/app.h" // LoaderArgs
// copybara:end
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
std::vector<int> tokenize( std::vector<int> tokenize(

0
experimental/.gitkeep Normal file
View File

3
experimental/README.md Normal file
View File

@ -0,0 +1,3 @@
# Experimental
This directory is for experimental code and features.

View File

@ -25,6 +25,8 @@
#include "compression/compress-inl.h" #include "compression/compress-inl.h"
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "ops.h" #include "ops.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // Path
#include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/contrib/matvec/matvec-inl.h"
#include "hwy/highway.h" #include "hwy/highway.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"
@ -50,6 +52,8 @@
#include <string> #include <string>
#include <vector> #include <vector>
// Placeholder for internal header, do not modify.
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "compression/compress.h" #include "compression/compress.h"
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
@ -68,12 +72,13 @@ template <class TConfig>
struct Layer { struct Layer {
Layer() = default; Layer() = default;
static constexpr size_t kHeads = TConfig::kHeads; static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads;
static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim; static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim; static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim;
// 3x for (query, key, value) static constexpr size_t kQKVEinsumWSize =
static constexpr size_t kQKVEinsumWSize = 3 * kHeads * kQKVDim * kModelDim; (kHeads + 2 * kKVHeads) * kQKVDim * kModelDim;
// 2x for (gelu gating vector, gated vector) // 2x for (gelu gating vector, gated vector)
static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim; static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
@ -311,47 +316,47 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
static constexpr size_t kModelDim = static constexpr size_t kModelDim =
gcpp::Activations<TConfig, kBatchSize>::kModelDim; gcpp::Activations<TConfig, kBatchSize>::kModelDim;
static constexpr size_t kHeads = TConfig::kHeads; static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads;
static const float kQueryScale = static const float kQueryScale =
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim))); static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { const size_t batch_offset = batch_idx * kModelDim;
// linear projections to QKV
const size_t head_offset =
3 * kQKVDim * kModelDim; // 3x for QKV dimensions
const size_t q_offset = head * head_offset + 0 * kQKVDim * kModelDim;
const size_t k_offset = head * head_offset + 1 * kQKVDim * kModelDim;
const size_t v_offset = head * head_offset + 2 * kQKVDim * kModelDim;
auto ProjQ = [&](uint64_t head, size_t head_offset) HWY_ATTR {
float* HWY_RESTRICT q = float* HWY_RESTRICT q =
activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim; activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
const size_t batch_offset = batch_idx * kModelDim;
MatVecLoop<kQKVDim, kModelDim>( MatVecLoop<kQKVDim, kModelDim>(
c_layer->c_qkv_einsum_w, q_offset, c_layer->c_qkv_einsum_w, head_offset + 0 * kQKVDim * kModelDim,
activations.pre_att_rms_out.data() + batch_offset, q); activations.pre_att_rms_out.data() + batch_offset, q);
};
const size_t kv_offset = auto ProjKV =
pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim; [&](size_t k_offset, size_t v_offset, size_t kv_offset) HWY_ATTR {
TwoOfsMatVecLoop<kQKVDim, kModelDim>(
c_layer->c_qkv_einsum_w, k_offset, v_offset,
activations.pre_att_rms_out.data() + batch_offset,
kv_cache.key_cache.get() + kv_offset,
kv_cache.value_cache.get() + kv_offset);
TwoOfsMatVecLoop<kQKVDim, kModelDim>( Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
c_layer->c_qkv_einsum_w, k_offset, v_offset, };
activations.pre_att_rms_out.data() + batch_offset,
kv_cache.key_cache.get() + kv_offset,
kv_cache.value_cache.get() + kv_offset);
auto Attn = [&](uint64_t head, size_t head_offset) HWY_ATTR {
// Calculate scores // Calculate scores
float* HWY_RESTRICT q =
activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
float* HWY_RESTRICT head_att = activations.att.data() + float* HWY_RESTRICT head_att = activations.att.data() +
head * TConfig::kSeqLen + head * TConfig::kSeqLen +
batch_idx * kHeads * kQKVDim; batch_idx * kHeads * kQKVDim;
Rope(q, kQKVDim, pos); Rope(q, kQKVDim, pos);
Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
MulByConst(kQueryScale, q, kQKVDim); MulByConst(kQueryScale, q, kQKVDim);
// Compute Q dot K scores // Compute Q dot K scores
for (size_t pos2 = 0; pos2 <= pos; ++pos2) { for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t cache_offset = const size_t cache_offset =
pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim; pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset; const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
const float score = Dot(q, k2, kQKVDim); const float score = Dot(q, k2, kQKVDim);
head_att[pos2] = score; head_att[pos2] = score;
@ -364,7 +369,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
for (size_t pos2 = 0; pos2 <= pos; ++pos2) { for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t cache_offset = const size_t cache_offset =
pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim; pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset; float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim); MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
} }
@ -377,7 +382,41 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
MatVecLoop<kModelDim, kQKVDim>(c_layer->c_attn_vec_einsum_w, MatVecLoop<kModelDim, kQKVDim>(c_layer->c_attn_vec_einsum_w,
head * kModelDim * kQKVDim, att_out, head * kModelDim * kQKVDim, att_out,
head_out); head_out);
}); };
if constexpr (kHeads == kKVHeads) {
// Multi-Head Attention
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
const size_t head_offset = head * 3 * kQKVDim * kModelDim;
ProjQ(head, head_offset);
const size_t k_offset = head_offset + 1 * kQKVDim * kModelDim;
const size_t v_offset = head_offset + 2 * kQKVDim * kModelDim;
const size_t kv_offset =
pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
ProjKV(k_offset, v_offset, kv_offset);
Attn(head, head * kQKVDim);
});
} else {
// Multi-Query Attention
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
ProjQ(head, head * kQKVDim * kModelDim);
});
constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize;
ProjKV(k_offset, v_offset, kv_offset);
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
Attn(head, 0);
});
}
// accumulate output across all heads into att_post2. head 0 already wrote // accumulate output across all heads into att_post2. head 0 already wrote
// directly to att_post2. // directly to att_post2.
@ -813,8 +852,9 @@ void GemmaImpl<ConfigGemma7B>::Generate(
} }
Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
const Path& weights_path, Model model_type, const Path& weights_path, Model model_type, ModelTraining training,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool)
: model_training(training) {
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer; std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
{ {
PROFILER_ZONE("Startup.tokenizer"); PROFILER_ZONE("Startup.tokenizer");
@ -839,11 +879,6 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
} }
} }
Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
Model model_type, hwy::ThreadPool& pool)
: Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type,
pool) {}
Gemma::~Gemma() = default; // after GemmaInterface is defined Gemma::~Gemma() = default; // after GemmaInterface is defined
const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const { const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {

16
gemma.h
View File

@ -16,29 +16,20 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_H_
#include <algorithm>
#include <cctype>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <random> #include <random>
#include <string>
#include <vector> #include <vector>
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "compression/compress.h" // SfpStream/NuqStream #include "compression/compress.h" // SfpStream/NuqStream
// copybara:end
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "configs.h" // kSeqLen #include "util/args.h" // Path
// copybara:end
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // ArgsBase
// copybara:end
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" // hwy::bfloat16_t #include "hwy/base.h" // hwy::bfloat16_t
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
// copybara:import_next_line:sentencepiece // copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h" #include "src/sentencepiece_processor.h"
// copybara:end
namespace gcpp { namespace gcpp {
@ -75,9 +66,8 @@ struct GemmaInterface;
struct Gemma { struct Gemma {
Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
const Path& weights_path, Model model_type, hwy::ThreadPool& pool); const Path& weights_path, Model model_type, ModelTraining training,
Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, hwy::ThreadPool& pool);
Model model_type, hwy::ThreadPool& pool);
~Gemma(); // must be defined after GemmaInterface's dtor is defined. ~Gemma(); // must be defined after GemmaInterface's dtor is defined.
const sentencepiece::SentencePieceProcessor* Tokenizer() const; const sentencepiece::SentencePieceProcessor* Tokenizer() const;
std::unique_ptr<GemmaInterface> impl_; std::unique_ptr<GemmaInterface> impl_;

15
ops.h
View File

@ -341,20 +341,21 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a,
static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2( static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
const float* HWY_RESTRICT a, size_t size) { const float* HWY_RESTRICT a, size_t size) {
const hn::ScalableTag<float> d; const hn::ScalableTag<float> d;
using V = hn::Vec<decltype(d)>;
const size_t N = hn::Lanes(d); const size_t N = hn::Lanes(d);
HWY_DASSERT(size >= 2 * N); HWY_DASSERT(size >= 2 * N);
HWY_DASSERT(size % (2 * N) == 0); HWY_DASSERT(size % (2 * N) == 0);
auto sum0 = hn::Zero(d); V sum0 = hn::Zero(d);
auto sum1 = hn::Zero(d); V sum1 = hn::Zero(d);
for (size_t i = 0; i <= size - 2 * N; i += 2 * N) { for (size_t i = 0; i <= size - 2 * N; i += 2 * N) {
const auto a0 = LoadU(d, a + i); const V a0 = hn::LoadU(d, a + i);
sum0 = MulAdd(a0, a0, sum0); sum0 = hn::MulAdd(a0, a0, sum0);
const auto a1 = LoadU(d, a + i + N); const V a1 = hn::LoadU(d, a + i + N);
sum1 = MulAdd(a1, a1, sum1); sum1 = hn::MulAdd(a1, a1, sum1);
} }
return ReduceSum(d, Add(sum0, sum1)); return hn::ReduceSum(d, hn::Add(sum0, sum1));
} }
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(

11
run.cc
View File

@ -22,18 +22,15 @@
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector> #include <vector>
// Placeholder for internal header, do not modify.
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "compression/compress.h" #include "compression/compress.h"
// copybara:end
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "gemma.h" // Gemma #include "gemma.h" // Gemma
// copybara:end
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "util/app.h" #include "util/app.h"
// copybara:end
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "util/args.h" // HasHelp #include "util/args.h" // HasHelp
// copybara:end
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h" #include "hwy/highway.h"
@ -234,8 +231,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
[](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
} }
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights,
loader.ModelType(), pool); loader.ModelType(), loader.ModelTraining(), pool);
auto kv_cache = CreateKVCache(loader.ModelType()); auto kv_cache = CreateKVCache(loader.ModelType());
@ -277,6 +274,8 @@ int main(int argc, char** argv) {
{ {
PROFILER_ZONE("Startup.misc"); PROFILER_ZONE("Startup.misc");
// Placeholder for internal init, do not modify.
gcpp::LoaderArgs loader(argc, argv); gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs inference(argc, argv); gcpp::InferenceArgs inference(argc, argv);
gcpp::AppArgs app(argc, argv); gcpp::AppArgs app(argc, argv);

View File

@ -34,15 +34,10 @@
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "configs.h" #include "configs.h"
// copybara:end
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "gemma.h" #include "gemma.h"
// copybara:end
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "util/args.h" #include "util/args.h"
// copybara:end
#include "hwy/base.h" // HWY_ASSERT #include "hwy/base.h" // HWY_ASSERT
namespace gcpp { namespace gcpp {

View File

@ -72,26 +72,12 @@ parser.add_argument(
args = parser.parse_args() args = parser.parse_args()
def expand_qkv(qkv_proj: np.array) -> np.array:
"""This won't be needed anymore when MQA is implemented"""
assert qkv_proj.shape == (2560, 2048)
qkv = qkv_proj.reshape((10, 256, 2048))
q_proj = qkv[:8].reshape((1,8,256,2048))
kv_proj = qkv[8:]
kv_proj = kv_proj[:, np.newaxis, :, :]
kv_proj = np.repeat(kv_proj, 8, axis=1)
qkv = np.concatenate([q_proj, kv_proj])
qkv = np.transpose(qkv, axes=[1,0,2,3])
return qkv
TRANSFORMATIONS = { TRANSFORMATIONS = {
"2b":defaultdict( "2b":defaultdict(
lambda: lambda x: x, lambda: lambda x: x,
{ {
"embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 2048])], 0), "embedder.weight": lambda x: x,
"self_attn.qkv_proj.weight": expand_qkv, "self_attn.qkv_proj.weight": lambda x: x.reshape((10, 256, 2048)),
"self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]), "self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]),
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
"mlp.up_proj.weight": lambda x: x[np.newaxis, :, :], "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
@ -101,7 +87,7 @@ TRANSFORMATIONS = {
"7b":defaultdict( "7b":defaultdict(
lambda: lambda x: x, lambda: lambda x: x,
{ {
"embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 3072])], 0), "embedder.weight": lambda x: x,
"self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1,0,2,3]), "self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1,0,2,3]),
"self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1,0,2]), "self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1,0,2]),
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
@ -113,9 +99,9 @@ TRANSFORMATIONS = {
VALIDATIONS = { VALIDATIONS = {
"2b": { "2b": {
"embedder.weight": lambda x: x.shape == (256128, 2048), "embedder.weight": lambda x: x.shape == (256000, 2048),
"model.norm.weight": lambda x: x.shape == (2048,), "model.norm.weight": lambda x: x.shape == (2048,),
"self_attn.qkv_proj.weight": lambda x: x.shape == (8, 3, 256, 2048), "self_attn.qkv_proj.weight": lambda x: x.shape == (10, 256, 2048),
"self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256), "self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256),
"mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048), "mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048),
"mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048), "mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048),
@ -124,7 +110,7 @@ VALIDATIONS = {
"post_attention_layernorm.weight": lambda x: x.shape == (2048,), "post_attention_layernorm.weight": lambda x: x.shape == (2048,),
}, },
"7b": { "7b": {
"embedder.weight": lambda x: x.shape == (256128, 3072), "embedder.weight": lambda x: x.shape == (256000, 3072),
"model.norm.weight": lambda x: x.shape == (3072,), "model.norm.weight": lambda x: x.shape == (3072,),
"self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072), "self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072),
"self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256), "self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256),