Merge branch 'dev' into feature-prompt-flag

This commit is contained in:
Prajwal Choudhari 2025-04-17 18:53:28 +05:30 committed by GitHub
commit 09dfb144c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 188 additions and 94 deletions

View File

@@ -82,7 +82,7 @@ jobs:
 subprocess.run(["cp", "/kaggle/input/gemma-build-artifacts/gemma", "/kaggle/working"])
 subprocess.run(["chmod", "700", "/kaggle/working/gemma"])
 subprocess.run(["cp", "/kaggle/input/gemma-build-artifacts/_deps/sentencepiece-build/src/libsentencepiece.so.0", "/kaggle/working"])
-output = subprocess.run(["/kaggle/working/gemma", "--tokenizer", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/tokenizer.spm", "--compressed_weights", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/2b-it-sfp.sbs", "--model", "2b-it", "--verbosity", "0", "--max_generated_tokens", "128"], stdout=subprocess.PIPE, input='Write an email to the moon.', encoding='ascii').stdout
+output = subprocess.run(["/kaggle/working/gemma", "--tokenizer", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/tokenizer.spm", "--weights", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/2b-it-sfp.sbs", "--model", "2b-it", "--verbosity", "0", "--max_generated_tokens", "128"], stdout=subprocess.PIPE, input='Write an email to the moon.', encoding='ascii').stdout
 assert("write an email to the moon." not in output.lower());
 assert("moon" in output.lower());
 EOF

View File

@@ -92,6 +92,8 @@ cc_library(
         ":basics",
         ":threading",
         ":topology",
+        "@highway//:hwy",
+        "@highway//:profiler",
     ],
 )
@@ -180,6 +182,7 @@ cc_library(
         "//compression:shared",
        "@highway//:hwy",
         "@highway//:profiler",
+        "@highway//:thread_pool",
     ],
 )
@@ -664,6 +667,7 @@ cc_test(
         ":mat",
         ":prompt",
         ":sampler",
+        ":threading_context",
         ":weights",
         "@googletest//:gtest_main",  # buildcleaner: keep
         "@highway//:thread_pool",

View File

@@ -70,6 +70,7 @@ cc_library(
     hdrs = ["blob_store.h"],
     deps = [
         ":io",
+        "//:basics",
         "//:threading_context",
         "@highway//:hwy",
         "@highway//:thread_pool",
@@ -130,7 +131,6 @@ cc_library(
     textual_hdrs = ["sfp-inl.h"],
     deps = [
         ":shared",
-        "//:basics",
         "@highway//:hwy",
     ],
 )
@@ -195,7 +195,6 @@ cc_test(
     deps = [
         ":distortion",
         ":nuq",
-        ":sfp",
         "@googletest//:gtest_main",  # buildcleaner: keep
         "//:test_util",
         "@highway//:hwy",
@@ -225,6 +224,7 @@ cc_library(
         "//:mat",
         "@highway//:hwy",
         "@highway//:nanobenchmark",
+        "@highway//:profiler",
         "@highway//:stats",
         "@highway//:thread_pool",
     ],
@@ -259,6 +259,7 @@ cc_library(
     deps = [
         ":nuq",
         ":sfp",
+        ":shared",
         "@highway//:hwy",
         "@highway//:stats",
         "@highway//:thread_pool",

View File

@@ -21,8 +21,7 @@
 #include <stdint.h>
 #include <stdio.h>
-#include <cmath>  // lroundf, only if COMPRESS_STATS
-#include <string>
+#include <memory>
 #include <vector>
 #include "compression/blob_store.h"
@@ -35,6 +34,10 @@
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/timer.h"
+#if COMPRESS_STATS
+#include <cmath>  // lroundf
+#endif
 #endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_INL_H_
 // Include guard for (potentially) SIMD code.
@@ -388,7 +391,7 @@ struct CompressTraits<SfpStream> {
                             const size_t packed_ofs) {
     SfpCodec::Enc(df, raw, num, packed.ptr + packed_ofs);
-    if (COMPRESS_STATS) {
+    if constexpr (COMPRESS_STATS) {
       const hn::Repartition<BF16, DF> dbf;
       auto distorted =
           hwy::AllocateAligned<BF16>(hwy::RoundUpTo(num, hn::Lanes(dbf)));
@@ -432,9 +435,10 @@ struct CompressTraits<NuqStream> {
                             size_t num, CompressPerThread& tls,
                             const PackedSpan<Packed>& packed,
                             const size_t packed_ofs) {
-    NuqCodec::Enc(df, raw, num, tls.buf, packed, packed_ofs);
-    if (COMPRESS_STATS) {
+    if (!tls.buf) tls.buf = std::make_unique<NuqStream::ClusterBuf>();
+    NuqCodec::Enc(df, raw, num, *tls.buf, packed, packed_ofs);
+    if constexpr (COMPRESS_STATS) {
       for (size_t i = 0; i < num; ++i) {
         tls.stats.NotifyIn(static_cast<int>(lroundf(raw[i] * 100.0f + 500.0f)));
       }
@@ -478,7 +482,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
                            const size_t packed_ofs, hwy::ThreadPool& pool) {
   packed.BoundsCheck(packed_ofs, num);
   work.tls.resize(pool.NumWorkers());
-  if (COMPRESS_STATS) {
+  if constexpr (COMPRESS_STATS) {
     for (auto& tls : work.tls) {
       tls.stats.Reset();
     }
@@ -487,7 +491,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
   const bool want_bench = COMPRESS_STATS || !kIsTest;
   const double t0 = want_bench ? hwy::platform::Now() : 0.0;
-  using Traits = CompressTraits<Packed>;
+  using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
   constexpr size_t kBatch = 8192;
   const size_t num_batches = hwy::DivCeil(num, kBatch);
   pool.Run(0, num_batches,
@@ -508,7 +512,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
     fprintf(stderr, "Compress %.1f MB/s\n", mbps);
   }
-  if (COMPRESS_STATS) {
+  if constexpr (COMPRESS_STATS) {
     for (size_t i = 1; i < work.tls.size(); ++i) {
       work.tls[0].stats.Assimilate(work.tls[i].stats);
     }
@@ -534,7 +538,7 @@ void Compress2(DF df, VF raw0, VF raw1, const PackedSpan<Packed>& packed,
                const size_t packed_ofs) {
   static_assert(hwy::IsSameEither<Packed, float, BF16>());
   packed.BoundsCheck(packed_ofs, 2 * hn::Lanes(df));
-  using Traits = CompressTraits<Packed>;
+  using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
   Traits::Store2(df, raw0, raw1, packed, packed_ofs);
 }
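The switch from `if (COMPRESS_STATS)` to `if constexpr (COMPRESS_STATS)` guarantees the stats blocks are discarded at compile time when the flag is 0; inside a template, the discarded branch is not even instantiated. A minimal sketch of the pattern (illustration only, not gemma.cpp code):

#define COMPRESS_STATS 0

template <typename Stats>
void NotifyIfEnabled(Stats& stats, float value) {
  if constexpr (COMPRESS_STATS) {
    // Discarded when the flag is 0: the branch is not instantiated, so
    // Stats need not even provide Notify() in that configuration.
    stats.Notify(value);
  }
}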

View File

@@ -15,8 +15,34 @@
 #include "compression/compress.h"
+#include <stddef.h>
+#include <stdint.h>
+#include "util/mat.h"
+#include "hwy/base.h"
+#include "hwy/profiler.h"
 namespace gcpp {
-// TODO: move ScaleWeights here.
+float ScaleWeights(float* HWY_RESTRICT raw, size_t num) {
+  PROFILER_FUNC;
+  float maxabs = 0.0;
+  for (size_t i = 0; i < num; ++i) {
+    maxabs = HWY_MAX(maxabs, hwy::ScalarAbs(raw[i]));
+  }
+  if (maxabs <= SfpStream::kMax) {
+    return 1.0f;
+  }
+  const float scale = maxabs / SfpStream::kMax;
+  const float inv_scale = static_cast<float>(1.0 / static_cast<double>(scale));
+  for (size_t i = 0; i < num; ++i) {
+    // Clamp because kMax may still be exceeded.
+    const float magn =
+        HWY_MIN(SfpStream::kMax, hwy::ScalarAbs(raw[i] * inv_scale));
+    raw[i] = hwy::ScalarCopySign(magn, raw[i]);
+  }
+  return scale;
+}
 }  // namespace gcpp
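With `ScaleWeights` now out of line, a call site looks like the following sketch (hypothetical; `PrepareForSfp` is not part of this commit). The returned multiplier must be kept, e.g. via `MatPtr::SetScale`, so the original magnitudes can be restored after decompression:

#include <vector>

#include "compression/compress.h"

// Returns the multiplier to re-apply after decompression; 1.0f means the
// weights already fit within [-SfpStream::kMax, SfpStream::kMax].
float PrepareForSfp(std::vector<float>& weights) {
  return gcpp::ScaleWeights(weights.data(), weights.size());
}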

View File

@@ -17,26 +17,19 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
 #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
+#include "hwy/base.h"
 #define COMPRESS_STATS 0
 #include <stddef.h>
 #include <stdint.h>
 #include <stdio.h>
-#include <cstdio>
-#include <cstring>
-#include <string>
-#include <unordered_map>
-#include <utility>
+#include <memory>
 #include <vector>
+// IWYU pragma: begin_exports
 #include "compression/blob_store.h"
 #include "compression/fields.h"
 #include "compression/io.h"
-#include "compression/shared.h"
-#include "gemma/tensor_index.h"
+#include "compression/shared.h"  // NuqStream::ClusterBuf
 #include "util/basics.h"
 // IWYU pragma: end_exports
 #include "gemma/configs.h"
@@ -174,7 +167,8 @@
 #endif  // COMPRESS_STATS
 struct CompressPerThread {
-  NuqStream::ClusterBuf buf;
+  // Allocated the first time NUQ is used.
+  std::unique_ptr<NuqStream::ClusterBuf> buf;
   CompressStats stats;
 };
@@ -375,5 +369,11 @@ class ReadFromBlobStore {
   std::vector<std::string> file_keys_;
 };
+// Returns 1.0f if all magnitudes are <= `SfpStream::kMax`, otherwise scales
+// them such that the largest magnitude is `SfpStream::kMax`, and returns the
+// multiplier with which to restore the original values. This is only necessary
+// before compressing to `SfpStream` and `NuqStream`.
+float ScaleWeights(float* HWY_RESTRICT raw, size_t num);
 }  // namespace gcpp
 #endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_

View File

@@ -13,8 +13,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-// Definitions shared between the public compress-inl.h interface and the
-// sfp-inl.h and nuq-inl.h implementation details.
+// Types shared between tensor definitions and `compress-inl.h`.
 #ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
 #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
@@ -63,30 +62,6 @@ struct SfpStream {
 };
 #pragma pack(pop)
-// Returns 1.0f if all magnitudes are <= SfpStream::kMax, otherwise scales them
-// such that the largest magnitude is SfpStream::kMax, and returns the
-// multiplier with which to restore the original values. This is only necessary
-// before compressing to SfpStream.
-// TODO: vectorize
-static inline float ScaleWeights(float* HWY_RESTRICT raw, size_t num) {
-  float maxabs = 0.0;
-  for (size_t i = 0; i < num; ++i) {
-    maxabs = HWY_MAX(maxabs, hwy::ScalarAbs(raw[i]));
-  }
-  if (maxabs <= SfpStream::kMax) {
-    return 1.0f;
-  }
-  const float scale = maxabs / SfpStream::kMax;
-  const float inv_scale = static_cast<float>(1.0 / static_cast<double>(scale));
-  for (size_t i = 0; i < num; ++i) {
-    // Clamp because kMax may still be exceeded.
-    const float magn =
-        HWY_MIN(SfpStream::kMax, hwy::ScalarAbs(raw[i] * inv_scale));
-    raw[i] = hwy::ScalarCopySign(magn, raw[i]);
-  }
-  return scale;
-}
 // Non-uniform quantization: a compressed representation of f32 inputs that
 // supports seeking at a granularity of 1 (for `DecompressAndZeroPad`) or
 // two vectors (for `Decompress2`), and decoding to bf16/f32.
@@ -185,20 +160,6 @@ constexpr bool IsNuqStream() {
   return hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>();
 }
-// Instruction-tuned models require extra 'turn structure' tokens in prompts.
-enum class PromptWrapping {
-  GEMMA_IT,
-  GEMMA_PT,
-  GEMMA_VLM,
-  PALIGEMMA,
-  kSentinel  // must be last
-};
-inline bool EnumValid(PromptWrapping type) {
-  return static_cast<int>(type) >= 0 &&
-         static_cast<int>(type) < static_cast<int>(PromptWrapping::kSentinel);
-}
 // Tensor types for loading weights. Note that not all types are supported as
 // weights for a model, but can be used for other purposes, such as types for
 // `WeightsPtrs`. When adding a new type that is supported, also

View File

@@ -49,6 +49,20 @@ static constexpr size_t kMaxConv1DWidth = 4;
 using EmbedderInputT = BF16;
+// Instruction-tuned models require extra 'turn structure' tokens in prompts.
+enum class PromptWrapping {
+  GEMMA_IT,
+  GEMMA_PT,
+  GEMMA_VLM,
+  PALIGEMMA,
+  kSentinel  // must be last
+};
+static inline bool EnumValid(PromptWrapping wrapping) {
+  return static_cast<size_t>(wrapping) <
+         static_cast<size_t>(PromptWrapping::kSentinel);
+}
 enum class LayerAttentionType {
   kGemma,
   kGriffinRecurrentBlock,
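The relocated `EnumValid` is also simpler than the shared.h version it replaces: casting to `size_t` sends negative values far past `kSentinel`, so the separate `>= 0` check is no longer needed. A sketch of the intended use when decoding an untrusted integer (assumption, not from this commit):

#include "gemma/configs.h"

#include "hwy/base.h"  // HWY_ASSERT

gcpp::PromptWrapping DecodeWrapping(int raw) {
  const auto wrapping = static_cast<gcpp::PromptWrapping>(raw);
  HWY_ASSERT(gcpp::EnumValid(wrapping));  // Rejects raw < 0 and raw >= kSentinel.
  return wrapping;
}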

View File

@@ -15,9 +15,6 @@
 #include <stddef.h>
-#include "compression/compress.h"
-#include "util/mat.h"
-#include "hwy/base.h"
 #include "hwy/profiler.h"
 // Include guard for (potentially) SIMD code.

View File

@@ -40,6 +40,7 @@ cc_test(
     ],
     deps = [
         "@googletest//:gtest_main",  # buildcleaner: keep
+        "//:allocator",
         "//:benchmark_helper",
         "//:common",
         "//:gemma_lib",

View File

@@ -20,7 +20,9 @@
 #include "compression/shared.h"
 #include "evals/benchmark_helper.h"
 #include "gemma/common.h"
+#include "gemma/configs.h"
 #include "gemma/gemma.h"
+#include "util/allocator.h"
 #include "hwy/base.h"
 #include "hwy/tests/hwy_gtest.h"
@@ -50,17 +52,18 @@ class PaliGemmaTest : public ::testing::Test {
 void PaliGemmaTest::InitVit(const std::string& path) {
   ASSERT_NE(s_env->GetGemma(), nullptr);
-  Gemma& model = *(s_env->GetGemma());
-  image_tokens_ =
-      ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len,
-                            model.GetModelConfig().model_dim));
+  const Allocator2& allocator = s_env->Env().ctx.allocator;
+  Gemma& gemma = *(s_env->GetGemma());
+  image_tokens_ = ImageTokens(
+      allocator, Extents2D(gemma.GetModelConfig().vit_config.seq_len,
+                           gemma.GetModelConfig().model_dim));
   Image image;
-  HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
+  HWY_ASSERT(gemma.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA);
   HWY_ASSERT(image.ReadPPM(path));
-  const size_t image_size = model.GetModelConfig().vit_config.image_size;
+  const size_t image_size = gemma.GetModelConfig().vit_config.image_size;
   image.Resize(image_size, image_size);
   RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0};
-  model.GenerateImageTokens(runtime_config, image, image_tokens_);
+  gemma.GenerateImageTokens(runtime_config, image, image_tokens_);
 }
 std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
@@ -124,7 +127,7 @@ TEST_F(PaliGemmaTest, General) {
   };
   const char* (*qa)[2];
   size_t num;
-  switch (s_env->GetGemma()->Info().model) {
+  switch (s_env->GetGemma()->GetModelConfig().model) {
     case Model::PALIGEMMA_224:
       qa = kQA_3B_mix_224;
       num = sizeof(kQA_3B_mix_224) / sizeof(kQA_3B_mix_224[0]);

View File

@@ -21,10 +21,10 @@ pybind_extension(
     name = "gemma",
     srcs = ["gemma_py.cc"],
     deps = [
-        "//:allocator",
         "//:benchmark_helper",
         "//:gemma_args",
         "//:gemma_lib",
+        "//:threading_context",
         "//compression:shared",
         "@highway//:hwy",
     ],

View File

@@ -43,6 +43,7 @@ PYBIND11_MODULE(configs, py_module) {
   enum_<PromptWrapping>(py_module, "PromptWrapping")
       .value("GEMMA_IT", PromptWrapping::GEMMA_IT)
       .value("GEMMA_PT", PromptWrapping::GEMMA_PT)
+      .value("GEMMA_VLM", PromptWrapping::GEMMA_VLM)
       .value("PALIGEMMA", PromptWrapping::PALIGEMMA);
   enum_<Type>(py_module, "Type")

View File

@@ -22,18 +22,16 @@
 #include <algorithm>
 #include <memory>
-#include <random>
 #include <set>
 #include <stdexcept>
 #include <string>
 #include <utility>
 #include <vector>
-#include "compression/shared.h"
 #include "evals/benchmark_helper.h"
 #include "gemma/gemma.h"
 #include "gemma/gemma_args.h"
-#include "util/allocator.h"
+#include "util/threading_context.h"
 #include "hwy/base.h"
 namespace py = pybind11;
@@ -169,9 +167,10 @@ class GemmaModel {
   // Generate* will use this image. Throws an error for other models.
   void SetImage(const py::array_t<float, py::array::c_style |
                                              py::array::forcecast>& image) {
+    gcpp::Gemma& gemma = *(gemma_.GetGemma());
     const gcpp::Allocator2& allocator = gemma_.Env().ctx.allocator;
-    gcpp::Gemma& model = *(gemma_.GetGemma());
-    if (model.Info().wrapping != gcpp::PromptWrapping::PALIGEMMA) {
+    if (gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::PALIGEMMA &&
+        gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::GEMMA_VLM) {
       throw std::invalid_argument("Not a PaliGemma model.");
     }
     py::buffer_info buffer = image.request();
@@ -183,14 +182,14 @@ class GemmaModel {
     float* ptr = static_cast<float*>(buffer.ptr);
     gcpp::Image c_image;
     c_image.Set(height, width, ptr);
-    const size_t image_size = model.GetModelConfig().vit_config.image_size;
+    const size_t image_size = gemma.GetModelConfig().vit_config.image_size;
     c_image.Resize(image_size, image_size);
     image_tokens_ = gcpp::ImageTokens(
-        allocator, gcpp::Extents2D(model.GetModelConfig().vit_config.seq_len,
-                                   model.GetModelConfig().model_dim));
+        allocator, gcpp::Extents2D(gemma.GetModelConfig().vit_config.seq_len,
+                                   gemma.GetModelConfig().model_dim));
     gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(),
                                           .verbosity = 0};
-    model.GenerateImageTokens(runtime_config, c_image, image_tokens_);
+    gemma.GenerateImageTokens(runtime_config, c_image, image_tokens_);
   }
   // Generates a response to the given prompt, using the last set image.
@@ -267,12 +266,12 @@ PYBIND11_MODULE(gemma, mod) {
           throw std::invalid_argument(err);
         }
         loader.weight_type_str = weight_type;
+        gcpp::ThreadingArgs threading;
+        threading.max_lps = max_threads;
         gcpp::InferenceArgs inference;
         inference.max_generated_tokens = 512;
-        gcpp::ThreadingArgs app;
-        app.max_threads = max_threads;
         auto gemma =
-            std::make_unique<GemmaModel>(loader, inference, app);
+            std::make_unique<GemmaModel>(loader, inference, threading);
         if (!gemma->ModelIsLoaded()) {
           throw std::invalid_argument("Could not load model.");
         }
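The binding now fills `ThreadingArgs` (the old `app` naming is gone) and caps parallelism via `max_lps`, which by its name appears to count logical processors rather than threads. The equivalent construction as a standalone sketch (helper name hypothetical):

#include <stddef.h>

#include "util/threading_context.h"  // ThreadingArgs

gcpp::ThreadingArgs MakeThreadingArgs(size_t max_threads) {
  gcpp::ThreadingArgs threading;
  threading.max_lps = max_threads;  // Field name per this diff.
  return threading;
}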

View File

@@ -18,8 +18,12 @@
 #include <stddef.h>
 #include <stdint.h>
+#include <random>
+#include <vector>
 #include "util/threading_context.h"
 #include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/per_target.h"  // VectorBytes
 #include "hwy/profiler.h"
@@ -27,8 +31,11 @@ namespace gcpp {
 void CopyMat(const MatPtr& from, MatPtr& to) {
   PROFILER_FUNC;
+  HWY_ASSERT_M(from.HasPtr() && to.HasPtr(), to.Name());
   HWY_ASSERT(to.Rows() == from.Rows() && to.Cols() == from.Cols());
   HWY_ASSERT(to.GetType() == from.GetType());
+  to.SetScale(from.Scale());
   if (to.IsPacked() && from.IsPacked()) {
     HWY_ASSERT(to.PackedBytes() == from.PackedBytes());
     hwy::CopyBytes(from.Packed(), to.Packed(), to.PackedBytes());
@@ -45,6 +52,8 @@ void CopyMat(const MatPtr& from, MatPtr& to) {
 void ZeroInit(MatPtr& mat) {
   PROFILER_FUNC;
   HWY_ASSERT_M(mat.HasPtr(), mat.Name());
+  mat.SetScale(1.0f);
   if (mat.IsPacked()) {
     hwy::ZeroBytes(mat.Packed(), mat.PackedBytes());
     return;
@@ -55,6 +64,31 @@ void ZeroInit(MatPtr& mat) {
   }
 }
+void RandInit(MatPtr& mat, float stddev, std::mt19937& gen) {
+  PROFILER_FUNC;
+  HWY_ASSERT_M(mat.HasPtr(), mat.Name());
+  // Only generates float/double for use by backprop/.
+  HWY_ASSERT(mat.GetType() == Type::kF32 || mat.GetType() == Type::kF64);
+  mat.SetScale(1.0f);
+  std::normal_distribution<float> dist(0.0, stddev);
+  if (mat.GetType() == Type::kF32) {
+    for (size_t r = 0; r < mat.Rows(); ++r) {
+      float* HWY_RESTRICT row = mat.RowT<float>(r);
+      for (size_t c = 0; c < mat.Cols(); ++c) {
+        row[c] = dist(gen);
+      }
+    }
+  } else {
+    for (size_t r = 0; r < mat.Rows(); ++r) {
+      double* HWY_RESTRICT row = mat.RowT<double>(r);
+      for (size_t c = 0; c < mat.Cols(); ++c) {
+        row[c] = dist(gen);
+      }
+    }
+  }
+}
 // Returns `num` rounded up to an odd number of cache lines. This would also
 // prevent 4K aliasing and is coprime with the cache associativity, which
 // might reduce conflict misses, but we instead use `StrideForCyclicOffsets`.
@@ -84,6 +118,7 @@ static size_t Stride(const Allocator2& allocator, const MatPtr& mat,
 }
 void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) {
+  if (mat.GetType() == Type::kNUQ) padding = MatPadding::kPacked;
   const Allocator2& allocator = ThreadingContext2::Get().allocator;
   const size_t stride = Stride(allocator, mat, padding);
   const size_t num = mat.Rows() * stride;
@@ -97,4 +132,16 @@ void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) {
   storage_ = allocator.AllocBytes(padded_bytes);
   mat.SetPtr(storage_.get(), stride);
 }
+void MatOwners::AllocateFor(const std::vector<MatPtr*>& mats,
+                            MatPadding padding, hwy::ThreadPool& pool) {
+  const size_t start = owners_.size();
+  owners_.resize(start + mats.size());
+  // Allocate in parallel because faulting in large tensors is slow.
+  pool.Run(0, mats.size(), [&](uint64_t task, size_t /*thread*/) {
+    owners_[start + task].AllocateFor(*mats[task], padding);
+  });
+}
 }  // namespace gcpp
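A sketch of the new parallel overload in use (assumption, not from this commit). `MatPadding::kPacked` is forced internally for NUQ tensors; the `MatOwners` must outlive the `MatPtr`s, which only view the storage:

#include <vector>

#include "util/mat.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

void AllocateWeights(const std::vector<gcpp::MatPtr*>& tensors,
                     gcpp::MatOwners& owners, hwy::ThreadPool& pool) {
  // Each tensor is faulted in on its own pool worker.
  owners.AllocateFor(tensors, gcpp::MatPadding::kPacked, pool);
}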

View File

@@ -22,6 +22,7 @@
 #include <random>
 #include <string>
+#include <vector>
 // IWYU pragma: begin_exports
 #include "compression/fields.h"
@@ -31,6 +32,7 @@
 #include "util/basics.h"  // Extents2D
 // IWYU pragma: end_exports
 #include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
 namespace gcpp {
@@ -71,7 +73,8 @@ class MatPtr : public IFields {
   bool HasPtr() const { return ptr_ != nullptr; }
-  bool IsPacked() const { return stride_ == cols_; }
+  // A single row counts as packed because there is no padding between rows.
+  bool IsPacked() const { return (stride_ == cols_) || (rows_ == 1); }
   const void* Packed() const {
     HWY_DASSERT_M(IsPacked(), name_.c_str());
@@ -132,11 +135,10 @@ class MatPtr : public IFields {
   float Scale() const { return scale_; }
   void SetScale(float scale) { scale_ = scale; }
-  // Name is a terse identifier. `MakeKey` in `blob_store.cc` requires that it
-  // be <= 16 bytes including prefixes/suffixes. The initial name set by the
-  // ctor is for the tensor, but `ForEachTensor` in `weights.h` adds a per-layer
-  // suffix, and when loading, we call `SetName` with that.
+  // A terse identifier unique across all tensors of the model.
   const char* Name() const override { return name_.c_str(); }
+  // `MakeKey` in `blob_store.cc` requires that this be <= 16 bytes, including
+  // the `LayerSuffix` for per-layer tensors.
   void SetName(const char* name) {
     name_ = name;
     HWY_ASSERT_M(name_.size() <= sizeof(hwy::uint128_t), name);
@@ -194,11 +196,13 @@ class MatPtr : public IFields {
   uint32_t stride_;
 };
-// Non-type erased version of `MatPtr`. Use this when operating on the values.
+// Non-type erased version of `MatPtr`. Although `MatPtr` also provides
+// type-aware accessors (`RowT`), this class is more convenient when accessing
+// elements, and ensures the template argument and `Type` are consistent.
 template <typename MatT>
 class MatPtrT : public MatPtr {
  public:
-  // Runtime-specified shape.
+  // Called by `MatStorageT`.
   MatPtrT(const char* name, Extents2D extents)
       : MatPtr(name, TypeEnum<MatT>(), extents) {}
   // Take shape from `TensorInfo` to avoid duplicating it in the caller.
@@ -247,6 +251,15 @@ class MatPtrT : public MatPtr {
     HWY_ASSERT(IsPacked());
     return MakeSpan(Row(0), num_elements_);
   }
+  // For when a span of a single row is required. This also works if padded,
+  // but does not support `GetType() == kNUQ`, because that requires the use of
+  // offsets instead of a row pointer. Used by `gemma-inl.h` to decompress
+  // embeddings.
+  PackedSpan<const MatT> RowSpan(size_t row) const {
+    HWY_DASSERT(GetType() != Type::kNUQ);
+    return MakeConstSpan(Row(row), Cols());
+  }
 };
 // Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT<T>`, plus the
@@ -340,6 +353,25 @@ class MatOwner {
   AlignedPtr2<uint8_t[]> storage_;
 };
+// Multiple `MatOwner`, with support for parallel allocation.
+class MatOwners {
+ public:
+  // Ignores `padding` for NUQ tensors, which are always packed.
+  void AllocateFor(MatPtr& mat, MatPadding padding) {
+    if (mat.GetType() == Type::kNUQ) padding = MatPadding::kPacked;
+    owners_.push_back(MatOwner());
+    owners_.back().AllocateFor(mat, padding);
+  }
+  // Allocates multiple in parallel. Ignores `padding` for NUQ tensors,
+  // which are always packed.
+  void AllocateFor(const std::vector<MatPtr*>& mats, MatPadding padding,
+                   hwy::ThreadPool& pool);
+ private:
+  std::vector<MatOwner> owners_;
+};
 // `MatStorageT` IS-A `MatPtrT` and HAS-A `MatOwner`. Used by `backprop/` and
 // tests to allocate and access tensors of a known type. By contrast, the
 // heterogeneous model weights are owned by vectors of `MatOwner`.
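A sketch of `RowSpan` on a possibly padded matrix (assumption, not from this commit): it exposes exactly `Cols()` elements of one row, so it also works when `IsPacked()` is false, but not for `Type::kNUQ`:

#include <stddef.h>

#include "util/mat.h"
#include "hwy/base.h"  // ConvertScalarTo

template <typename T>
float RowSum(const gcpp::MatPtrT<T>& mat, size_t row) {
  const gcpp::PackedSpan<const T> span = mat.RowSpan(row);
  float sum = 0.0f;
  for (size_t c = 0; c < mat.Cols(); ++c) {
    sum += hwy::ConvertScalarTo<float>(span.ptr[c]);
  }
  return sum;
}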

View File

@@ -18,6 +18,9 @@
 #include <memory>
 #include <mutex>  // NOLINT
+#include "hwy/base.h"  // HWY_ASSERT, HWY_UNLIKELY
+#include "hwy/profiler.h"
 namespace gcpp {
 static ThreadingArgs s_args;
@@ -41,6 +44,7 @@ static std::mutex s_ctx_mutex;
 }
 /*static*/ ThreadingContext2& ThreadingContext2::Get() {
+  PROFILER_FUNC;
   // We do not bother with double-checked locking because it requires an
   // atomic pointer, but we prefer to use unique_ptr for simplicity. Also,
   // callers can cache the result and call less often.
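Because `Get()` now also enters the profiler and always takes `s_ctx_mutex`, the advice in the comment matters: fetch the context once and hold the reference. A minimal sketch (assumption, not from this commit):

#include "util/threading_context.h"

void UseContextOnce() {
  // One locked lookup; the reference remains valid for the process lifetime.
  gcpp::ThreadingContext2& ctx = gcpp::ThreadingContext2::Get();
  const gcpp::Allocator2& allocator = ctx.allocator;  // Member per this diff.
  (void)allocator;
}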