mirror of https://github.com/google/gemma.cpp.git
Merge branch 'dev' into feature-prompt-flag
This commit is contained in:
commit
09dfb144c0
|
|
@ -82,7 +82,7 @@ jobs:
|
||||||
subprocess.run(["cp", "/kaggle/input/gemma-build-artifacts/gemma", "/kaggle/working"])
|
subprocess.run(["cp", "/kaggle/input/gemma-build-artifacts/gemma", "/kaggle/working"])
|
||||||
subprocess.run(["chmod", "700", "/kaggle/working/gemma"])
|
subprocess.run(["chmod", "700", "/kaggle/working/gemma"])
|
||||||
subprocess.run(["cp", "/kaggle/input/gemma-build-artifacts/_deps/sentencepiece-build/src/libsentencepiece.so.0", "/kaggle/working"])
|
subprocess.run(["cp", "/kaggle/input/gemma-build-artifacts/_deps/sentencepiece-build/src/libsentencepiece.so.0", "/kaggle/working"])
|
||||||
output = subprocess.run(["/kaggle/working/gemma", "--tokenizer", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/tokenizer.spm", "--compressed_weights", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/2b-it-sfp.sbs", "--model", "2b-it", "--verbosity", "0", "--max_generated_tokens", "128"], stdout=subprocess.PIPE, input='Write an email to the moon.', encoding='ascii').stdout
|
output = subprocess.run(["/kaggle/working/gemma", "--tokenizer", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/tokenizer.spm", "--weights", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/2b-it-sfp.sbs", "--model", "2b-it", "--verbosity", "0", "--max_generated_tokens", "128"], stdout=subprocess.PIPE, input='Write an email to the moon.', encoding='ascii').stdout
|
||||||
assert("write an email to the moon." not in output.lower());
|
assert("write an email to the moon." not in output.lower());
|
||||||
assert("moon" in output.lower());
|
assert("moon" in output.lower());
|
||||||
EOF
|
EOF
|
||||||
|
|
|
||||||
|
|
@ -92,6 +92,8 @@ cc_library(
|
||||||
":basics",
|
":basics",
|
||||||
":threading",
|
":threading",
|
||||||
":topology",
|
":topology",
|
||||||
|
"@highway//:hwy",
|
||||||
|
"@highway//:profiler",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -180,6 +182,7 @@ cc_library(
|
||||||
"//compression:shared",
|
"//compression:shared",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
"@highway//:profiler",
|
"@highway//:profiler",
|
||||||
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -664,6 +667,7 @@ cc_test(
|
||||||
":mat",
|
":mat",
|
||||||
":prompt",
|
":prompt",
|
||||||
":sampler",
|
":sampler",
|
||||||
|
":threading_context",
|
||||||
":weights",
|
":weights",
|
||||||
"@googletest//:gtest_main", # buildcleaner: keep
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
"@highway//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
|
|
|
||||||
|
|
@ -70,6 +70,7 @@ cc_library(
|
||||||
hdrs = ["blob_store.h"],
|
hdrs = ["blob_store.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":io",
|
":io",
|
||||||
|
"//:basics",
|
||||||
"//:threading_context",
|
"//:threading_context",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
"@highway//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
|
|
@ -130,7 +131,6 @@ cc_library(
|
||||||
textual_hdrs = ["sfp-inl.h"],
|
textual_hdrs = ["sfp-inl.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":shared",
|
":shared",
|
||||||
"//:basics",
|
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
@ -195,7 +195,6 @@ cc_test(
|
||||||
deps = [
|
deps = [
|
||||||
":distortion",
|
":distortion",
|
||||||
":nuq",
|
":nuq",
|
||||||
":sfp",
|
|
||||||
"@googletest//:gtest_main", # buildcleaner: keep
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
"//:test_util",
|
"//:test_util",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
|
|
@ -225,6 +224,7 @@ cc_library(
|
||||||
"//:mat",
|
"//:mat",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
"@highway//:nanobenchmark",
|
"@highway//:nanobenchmark",
|
||||||
|
"@highway//:profiler",
|
||||||
"@highway//:stats",
|
"@highway//:stats",
|
||||||
"@highway//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
|
|
@ -259,6 +259,7 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":nuq",
|
":nuq",
|
||||||
":sfp",
|
":sfp",
|
||||||
|
":shared",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
"@highway//:stats",
|
"@highway//:stats",
|
||||||
"@highway//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
|
|
|
||||||
|
|
@ -21,8 +21,7 @@
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include <cmath> // lroundf, only if COMPRESS_STATS
|
#include <memory>
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "compression/blob_store.h"
|
#include "compression/blob_store.h"
|
||||||
|
|
@ -35,6 +34,10 @@
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
|
|
||||||
|
#if COMPRESS_STATS
|
||||||
|
#include <cmath> // lroundf
|
||||||
|
#endif
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_INL_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_INL_H_
|
||||||
|
|
||||||
// Include guard for (potentially) SIMD code.
|
// Include guard for (potentially) SIMD code.
|
||||||
|
|
@ -388,7 +391,7 @@ struct CompressTraits<SfpStream> {
|
||||||
const size_t packed_ofs) {
|
const size_t packed_ofs) {
|
||||||
SfpCodec::Enc(df, raw, num, packed.ptr + packed_ofs);
|
SfpCodec::Enc(df, raw, num, packed.ptr + packed_ofs);
|
||||||
|
|
||||||
if (COMPRESS_STATS) {
|
if constexpr (COMPRESS_STATS) {
|
||||||
const hn::Repartition<BF16, DF> dbf;
|
const hn::Repartition<BF16, DF> dbf;
|
||||||
auto distorted =
|
auto distorted =
|
||||||
hwy::AllocateAligned<BF16>(hwy::RoundUpTo(num, hn::Lanes(dbf)));
|
hwy::AllocateAligned<BF16>(hwy::RoundUpTo(num, hn::Lanes(dbf)));
|
||||||
|
|
@ -432,9 +435,10 @@ struct CompressTraits<NuqStream> {
|
||||||
size_t num, CompressPerThread& tls,
|
size_t num, CompressPerThread& tls,
|
||||||
const PackedSpan<Packed>& packed,
|
const PackedSpan<Packed>& packed,
|
||||||
const size_t packed_ofs) {
|
const size_t packed_ofs) {
|
||||||
NuqCodec::Enc(df, raw, num, tls.buf, packed, packed_ofs);
|
if (!tls.buf) tls.buf = std::make_unique<NuqStream::ClusterBuf>();
|
||||||
|
NuqCodec::Enc(df, raw, num, *tls.buf, packed, packed_ofs);
|
||||||
|
|
||||||
if (COMPRESS_STATS) {
|
if constexpr (COMPRESS_STATS) {
|
||||||
for (size_t i = 0; i < num; ++i) {
|
for (size_t i = 0; i < num; ++i) {
|
||||||
tls.stats.NotifyIn(static_cast<int>(lroundf(raw[i] * 100.0f + 500.0f)));
|
tls.stats.NotifyIn(static_cast<int>(lroundf(raw[i] * 100.0f + 500.0f)));
|
||||||
}
|
}
|
||||||
|
|
@ -478,7 +482,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
|
||||||
const size_t packed_ofs, hwy::ThreadPool& pool) {
|
const size_t packed_ofs, hwy::ThreadPool& pool) {
|
||||||
packed.BoundsCheck(packed_ofs, num);
|
packed.BoundsCheck(packed_ofs, num);
|
||||||
work.tls.resize(pool.NumWorkers());
|
work.tls.resize(pool.NumWorkers());
|
||||||
if (COMPRESS_STATS) {
|
if constexpr (COMPRESS_STATS) {
|
||||||
for (auto& tls : work.tls) {
|
for (auto& tls : work.tls) {
|
||||||
tls.stats.Reset();
|
tls.stats.Reset();
|
||||||
}
|
}
|
||||||
|
|
@ -487,7 +491,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
|
||||||
const bool want_bench = COMPRESS_STATS || !kIsTest;
|
const bool want_bench = COMPRESS_STATS || !kIsTest;
|
||||||
const double t0 = want_bench ? hwy::platform::Now() : 0.0;
|
const double t0 = want_bench ? hwy::platform::Now() : 0.0;
|
||||||
|
|
||||||
using Traits = CompressTraits<Packed>;
|
using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
|
||||||
constexpr size_t kBatch = 8192;
|
constexpr size_t kBatch = 8192;
|
||||||
const size_t num_batches = hwy::DivCeil(num, kBatch);
|
const size_t num_batches = hwy::DivCeil(num, kBatch);
|
||||||
pool.Run(0, num_batches,
|
pool.Run(0, num_batches,
|
||||||
|
|
@ -508,7 +512,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
|
||||||
fprintf(stderr, "Compress %.1f MB/s\n", mbps);
|
fprintf(stderr, "Compress %.1f MB/s\n", mbps);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (COMPRESS_STATS) {
|
if constexpr (COMPRESS_STATS) {
|
||||||
for (size_t i = 1; i < work.tls.size(); ++i) {
|
for (size_t i = 1; i < work.tls.size(); ++i) {
|
||||||
work.tls[0].stats.Assimilate(work.tls[i].stats);
|
work.tls[0].stats.Assimilate(work.tls[i].stats);
|
||||||
}
|
}
|
||||||
|
|
@ -534,7 +538,7 @@ void Compress2(DF df, VF raw0, VF raw1, const PackedSpan<Packed>& packed,
|
||||||
const size_t packed_ofs) {
|
const size_t packed_ofs) {
|
||||||
static_assert(hwy::IsSameEither<Packed, float, BF16>());
|
static_assert(hwy::IsSameEither<Packed, float, BF16>());
|
||||||
packed.BoundsCheck(packed_ofs, 2 * hn::Lanes(df));
|
packed.BoundsCheck(packed_ofs, 2 * hn::Lanes(df));
|
||||||
using Traits = CompressTraits<Packed>;
|
using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
|
||||||
Traits::Store2(df, raw0, raw1, packed, packed_ofs);
|
Traits::Store2(df, raw0, raw1, packed, packed_ofs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,34 @@
|
||||||
|
|
||||||
#include "compression/compress.h"
|
#include "compression/compress.h"
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "util/mat.h"
|
||||||
|
#include "hwy/base.h"
|
||||||
|
#include "hwy/profiler.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
// TODO: move ScaleWeights here.
|
float ScaleWeights(float* HWY_RESTRICT raw, size_t num) {
|
||||||
|
PROFILER_FUNC;
|
||||||
|
|
||||||
|
float maxabs = 0.0;
|
||||||
|
for (size_t i = 0; i < num; ++i) {
|
||||||
|
maxabs = HWY_MAX(maxabs, hwy::ScalarAbs(raw[i]));
|
||||||
|
}
|
||||||
|
if (maxabs <= SfpStream::kMax) {
|
||||||
|
return 1.0f;
|
||||||
|
}
|
||||||
|
const float scale = maxabs / SfpStream::kMax;
|
||||||
|
const float inv_scale = static_cast<float>(1.0 / static_cast<double>(scale));
|
||||||
|
for (size_t i = 0; i < num; ++i) {
|
||||||
|
// Clamp because kMax may still be exceeded.
|
||||||
|
const float magn =
|
||||||
|
HWY_MIN(SfpStream::kMax, hwy::ScalarAbs(raw[i] * inv_scale));
|
||||||
|
raw[i] = hwy::ScalarCopySign(magn, raw[i]);
|
||||||
|
}
|
||||||
|
return scale;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
|
|
@ -17,26 +17,19 @@
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
|
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
|
||||||
|
|
||||||
#include "hwy/base.h"
|
|
||||||
#define COMPRESS_STATS 0
|
#define COMPRESS_STATS 0
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include <cstdio>
|
#include <memory>
|
||||||
#include <cstring>
|
|
||||||
#include <string>
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// IWYU pragma: begin_exports
|
|
||||||
#include "compression/blob_store.h"
|
#include "compression/blob_store.h"
|
||||||
#include "compression/fields.h"
|
#include "compression/fields.h"
|
||||||
#include "compression/io.h"
|
#include "compression/io.h"
|
||||||
#include "compression/shared.h"
|
#include "compression/shared.h" // NuqStream::ClusterBuf
|
||||||
#include "gemma/tensor_index.h"
|
|
||||||
#include "util/basics.h"
|
#include "util/basics.h"
|
||||||
// IWYU pragma: end_exports
|
// IWYU pragma: end_exports
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
|
|
@ -174,7 +167,8 @@ struct CompressStats {
|
||||||
#endif // COMPRESS_STATS
|
#endif // COMPRESS_STATS
|
||||||
|
|
||||||
struct CompressPerThread {
|
struct CompressPerThread {
|
||||||
NuqStream::ClusterBuf buf;
|
// Allocated the first time NUQ is used.
|
||||||
|
std::unique_ptr<NuqStream::ClusterBuf> buf;
|
||||||
CompressStats stats;
|
CompressStats stats;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -375,5 +369,11 @@ class ReadFromBlobStore {
|
||||||
std::vector<std::string> file_keys_;
|
std::vector<std::string> file_keys_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Returns 1.0f if all magnitudes are <= `SfpStream::kMax`, otherwise scales
|
||||||
|
// them such that the largest magnitude is `SfpStream::kMax`, and returns the
|
||||||
|
// multiplier with which to restore the original values. This is only necessary
|
||||||
|
// before compressing to `SfpStream` and `NuqStream`.
|
||||||
|
float ScaleWeights(float* HWY_RESTRICT raw, size_t num);
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
|
||||||
|
|
|
||||||
|
|
@ -13,8 +13,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
// Definitions shared between the public compress-inl.h interface and the
|
// Types shared between tensor definitions and `compress-inl.h`.
|
||||||
// sfp-inl.h and nuq-inl.h implementation details.
|
|
||||||
|
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
|
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
|
||||||
|
|
@ -63,30 +62,6 @@ struct SfpStream {
|
||||||
};
|
};
|
||||||
#pragma pack(pop)
|
#pragma pack(pop)
|
||||||
|
|
||||||
// Returns 1.0f if all magnitudes are <= SfpStream::kMax, otherwise scales them
|
|
||||||
// such that the largest magnitude is SfpStream::kMax, and returns the
|
|
||||||
// multiplier with which to restore the original values. This is only necessary
|
|
||||||
// before compressing to SfpStream.
|
|
||||||
// TODO: vectorize
|
|
||||||
static inline float ScaleWeights(float* HWY_RESTRICT raw, size_t num) {
|
|
||||||
float maxabs = 0.0;
|
|
||||||
for (size_t i = 0; i < num; ++i) {
|
|
||||||
maxabs = HWY_MAX(maxabs, hwy::ScalarAbs(raw[i]));
|
|
||||||
}
|
|
||||||
if (maxabs <= SfpStream::kMax) {
|
|
||||||
return 1.0f;
|
|
||||||
}
|
|
||||||
const float scale = maxabs / SfpStream::kMax;
|
|
||||||
const float inv_scale = static_cast<float>(1.0 / static_cast<double>(scale));
|
|
||||||
for (size_t i = 0; i < num; ++i) {
|
|
||||||
// Clamp because kMax may still be exceeded.
|
|
||||||
const float magn =
|
|
||||||
HWY_MIN(SfpStream::kMax, hwy::ScalarAbs(raw[i] * inv_scale));
|
|
||||||
raw[i] = hwy::ScalarCopySign(magn, raw[i]);
|
|
||||||
}
|
|
||||||
return scale;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Non-uniform quantization: a compressed representation of f32 inputs that
|
// Non-uniform quantization: a compressed representation of f32 inputs that
|
||||||
// supports seeking at a granularity of 1 (for `DecompressAndZeroPad`) or
|
// supports seeking at a granularity of 1 (for `DecompressAndZeroPad`) or
|
||||||
// two vectors (for `Decompress2`), and decoding to bf16/f32.
|
// two vectors (for `Decompress2`), and decoding to bf16/f32.
|
||||||
|
|
@ -185,20 +160,6 @@ constexpr bool IsNuqStream() {
|
||||||
return hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>();
|
return hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
|
|
||||||
enum class PromptWrapping {
|
|
||||||
GEMMA_IT,
|
|
||||||
GEMMA_PT,
|
|
||||||
GEMMA_VLM,
|
|
||||||
PALIGEMMA,
|
|
||||||
kSentinel // must be last
|
|
||||||
};
|
|
||||||
|
|
||||||
inline bool EnumValid(PromptWrapping type) {
|
|
||||||
return static_cast<int>(type) >= 0 &&
|
|
||||||
static_cast<int>(type) < static_cast<int>(PromptWrapping::kSentinel);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tensor types for loading weights. Note that not all types are supported as
|
// Tensor types for loading weights. Note that not all types are supported as
|
||||||
// weights for a model, but can be used for other purposes, such as types for
|
// weights for a model, but can be used for other purposes, such as types for
|
||||||
// `WeightsPtrs`. When adding a new type that is supported, also
|
// `WeightsPtrs`. When adding a new type that is supported, also
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,20 @@ static constexpr size_t kMaxConv1DWidth = 4;
|
||||||
|
|
||||||
using EmbedderInputT = BF16;
|
using EmbedderInputT = BF16;
|
||||||
|
|
||||||
|
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
|
||||||
|
enum class PromptWrapping {
|
||||||
|
GEMMA_IT,
|
||||||
|
GEMMA_PT,
|
||||||
|
GEMMA_VLM,
|
||||||
|
PALIGEMMA,
|
||||||
|
kSentinel // must be last
|
||||||
|
};
|
||||||
|
|
||||||
|
static inline bool EnumValid(PromptWrapping wrapping) {
|
||||||
|
return static_cast<size_t>(wrapping) <
|
||||||
|
static_cast<size_t>(PromptWrapping::kSentinel);
|
||||||
|
}
|
||||||
|
|
||||||
enum class LayerAttentionType {
|
enum class LayerAttentionType {
|
||||||
kGemma,
|
kGemma,
|
||||||
kGriffinRecurrentBlock,
|
kGriffinRecurrentBlock,
|
||||||
|
|
|
||||||
|
|
@ -15,9 +15,6 @@
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
#include "compression/compress.h"
|
|
||||||
#include "util/mat.h"
|
|
||||||
#include "hwy/base.h"
|
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
|
|
||||||
// Include guard for (potentially) SIMD code.
|
// Include guard for (potentially) SIMD code.
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,7 @@ cc_test(
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"@googletest//:gtest_main", # buildcleaner: keep
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
|
"//:allocator",
|
||||||
"//:benchmark_helper",
|
"//:benchmark_helper",
|
||||||
"//:common",
|
"//:common",
|
||||||
"//:gemma_lib",
|
"//:gemma_lib",
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,9 @@
|
||||||
#include "compression/shared.h"
|
#include "compression/shared.h"
|
||||||
#include "evals/benchmark_helper.h"
|
#include "evals/benchmark_helper.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
|
#include "gemma/configs.h"
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/gemma.h"
|
||||||
|
#include "util/allocator.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/tests/hwy_gtest.h"
|
#include "hwy/tests/hwy_gtest.h"
|
||||||
|
|
||||||
|
|
@ -50,17 +52,18 @@ class PaliGemmaTest : public ::testing::Test {
|
||||||
|
|
||||||
void PaliGemmaTest::InitVit(const std::string& path) {
|
void PaliGemmaTest::InitVit(const std::string& path) {
|
||||||
ASSERT_NE(s_env->GetGemma(), nullptr);
|
ASSERT_NE(s_env->GetGemma(), nullptr);
|
||||||
Gemma& model = *(s_env->GetGemma());
|
const Allocator2& allocator = s_env->Env().ctx.allocator;
|
||||||
image_tokens_ =
|
Gemma& gemma = *(s_env->GetGemma());
|
||||||
ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len,
|
image_tokens_ = ImageTokens(
|
||||||
model.GetModelConfig().model_dim));
|
allocator, Extents2D(gemma.GetModelConfig().vit_config.seq_len,
|
||||||
|
gemma.GetModelConfig().model_dim));
|
||||||
Image image;
|
Image image;
|
||||||
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
|
HWY_ASSERT(gemma.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA);
|
||||||
HWY_ASSERT(image.ReadPPM(path));
|
HWY_ASSERT(image.ReadPPM(path));
|
||||||
const size_t image_size = model.GetModelConfig().vit_config.image_size;
|
const size_t image_size = gemma.GetModelConfig().vit_config.image_size;
|
||||||
image.Resize(image_size, image_size);
|
image.Resize(image_size, image_size);
|
||||||
RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0};
|
RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0};
|
||||||
model.GenerateImageTokens(runtime_config, image, image_tokens_);
|
gemma.GenerateImageTokens(runtime_config, image, image_tokens_);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
|
std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
|
||||||
|
|
@ -124,7 +127,7 @@ TEST_F(PaliGemmaTest, General) {
|
||||||
};
|
};
|
||||||
const char* (*qa)[2];
|
const char* (*qa)[2];
|
||||||
size_t num;
|
size_t num;
|
||||||
switch (s_env->GetGemma()->Info().model) {
|
switch (s_env->GetGemma()->GetModelConfig().model) {
|
||||||
case Model::PALIGEMMA_224:
|
case Model::PALIGEMMA_224:
|
||||||
qa = kQA_3B_mix_224;
|
qa = kQA_3B_mix_224;
|
||||||
num = sizeof(kQA_3B_mix_224) / sizeof(kQA_3B_mix_224[0]);
|
num = sizeof(kQA_3B_mix_224) / sizeof(kQA_3B_mix_224[0]);
|
||||||
|
|
|
||||||
|
|
@ -21,10 +21,10 @@ pybind_extension(
|
||||||
name = "gemma",
|
name = "gemma",
|
||||||
srcs = ["gemma_py.cc"],
|
srcs = ["gemma_py.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
"//:allocator",
|
|
||||||
"//:benchmark_helper",
|
"//:benchmark_helper",
|
||||||
"//:gemma_args",
|
"//:gemma_args",
|
||||||
"//:gemma_lib",
|
"//:gemma_lib",
|
||||||
|
"//:threading_context",
|
||||||
"//compression:shared",
|
"//compression:shared",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
],
|
],
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,7 @@ PYBIND11_MODULE(configs, py_module) {
|
||||||
enum_<PromptWrapping>(py_module, "PromptWrapping")
|
enum_<PromptWrapping>(py_module, "PromptWrapping")
|
||||||
.value("GEMMA_IT", PromptWrapping::GEMMA_IT)
|
.value("GEMMA_IT", PromptWrapping::GEMMA_IT)
|
||||||
.value("GEMMA_PT", PromptWrapping::GEMMA_PT)
|
.value("GEMMA_PT", PromptWrapping::GEMMA_PT)
|
||||||
|
.value("GEMMA_VLM", PromptWrapping::GEMMA_VLM)
|
||||||
.value("PALIGEMMA", PromptWrapping::PALIGEMMA);
|
.value("PALIGEMMA", PromptWrapping::PALIGEMMA);
|
||||||
|
|
||||||
enum_<Type>(py_module, "Type")
|
enum_<Type>(py_module, "Type")
|
||||||
|
|
|
||||||
|
|
@ -22,18 +22,16 @@
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <random>
|
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "compression/shared.h"
|
|
||||||
#include "evals/benchmark_helper.h"
|
#include "evals/benchmark_helper.h"
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/gemma.h"
|
||||||
#include "gemma/gemma_args.h"
|
#include "gemma/gemma_args.h"
|
||||||
#include "util/allocator.h"
|
#include "util/threading_context.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
|
@ -169,9 +167,10 @@ class GemmaModel {
|
||||||
// Generate* will use this image. Throws an error for other models.
|
// Generate* will use this image. Throws an error for other models.
|
||||||
void SetImage(const py::array_t<float, py::array::c_style |
|
void SetImage(const py::array_t<float, py::array::c_style |
|
||||||
py::array::forcecast>& image) {
|
py::array::forcecast>& image) {
|
||||||
|
gcpp::Gemma& gemma = *(gemma_.GetGemma());
|
||||||
const gcpp::Allocator2& allocator = gemma_.Env().ctx.allocator;
|
const gcpp::Allocator2& allocator = gemma_.Env().ctx.allocator;
|
||||||
gcpp::Gemma& model = *(gemma_.GetGemma());
|
if (gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::PALIGEMMA &&
|
||||||
if (model.Info().wrapping != gcpp::PromptWrapping::PALIGEMMA) {
|
gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::GEMMA_VLM) {
|
||||||
throw std::invalid_argument("Not a PaliGemma model.");
|
throw std::invalid_argument("Not a PaliGemma model.");
|
||||||
}
|
}
|
||||||
py::buffer_info buffer = image.request();
|
py::buffer_info buffer = image.request();
|
||||||
|
|
@ -183,14 +182,14 @@ class GemmaModel {
|
||||||
float* ptr = static_cast<float*>(buffer.ptr);
|
float* ptr = static_cast<float*>(buffer.ptr);
|
||||||
gcpp::Image c_image;
|
gcpp::Image c_image;
|
||||||
c_image.Set(height, width, ptr);
|
c_image.Set(height, width, ptr);
|
||||||
const size_t image_size = model.GetModelConfig().vit_config.image_size;
|
const size_t image_size = gemma.GetModelConfig().vit_config.image_size;
|
||||||
c_image.Resize(image_size, image_size);
|
c_image.Resize(image_size, image_size);
|
||||||
image_tokens_ = gcpp::ImageTokens(
|
image_tokens_ = gcpp::ImageTokens(
|
||||||
allocator, gcpp::Extents2D(model.GetModelConfig().vit_config.seq_len,
|
allocator, gcpp::Extents2D(gemma.GetModelConfig().vit_config.seq_len,
|
||||||
model.GetModelConfig().model_dim));
|
gemma.GetModelConfig().model_dim));
|
||||||
gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(),
|
gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(),
|
||||||
.verbosity = 0};
|
.verbosity = 0};
|
||||||
model.GenerateImageTokens(runtime_config, c_image, image_tokens_);
|
gemma.GenerateImageTokens(runtime_config, c_image, image_tokens_);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generates a response to the given prompt, using the last set image.
|
// Generates a response to the given prompt, using the last set image.
|
||||||
|
|
@ -267,12 +266,12 @@ PYBIND11_MODULE(gemma, mod) {
|
||||||
throw std::invalid_argument(err);
|
throw std::invalid_argument(err);
|
||||||
}
|
}
|
||||||
loader.weight_type_str = weight_type;
|
loader.weight_type_str = weight_type;
|
||||||
|
gcpp::ThreadingArgs threading;
|
||||||
|
threading.max_lps = max_threads;
|
||||||
gcpp::InferenceArgs inference;
|
gcpp::InferenceArgs inference;
|
||||||
inference.max_generated_tokens = 512;
|
inference.max_generated_tokens = 512;
|
||||||
gcpp::ThreadingArgs app;
|
|
||||||
app.max_threads = max_threads;
|
|
||||||
auto gemma =
|
auto gemma =
|
||||||
std::make_unique<GemmaModel>(loader, inference, app);
|
std::make_unique<GemmaModel>(loader, inference, threading);
|
||||||
if (!gemma->ModelIsLoaded()) {
|
if (!gemma->ModelIsLoaded()) {
|
||||||
throw std::invalid_argument("Could not load model.");
|
throw std::invalid_argument("Could not load model.");
|
||||||
}
|
}
|
||||||
|
|
|
||||||
47
util/mat.cc
47
util/mat.cc
|
|
@ -18,8 +18,12 @@
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include <random>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "util/threading_context.h"
|
#include "util/threading_context.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/per_target.h" // VectorBytes
|
#include "hwy/per_target.h" // VectorBytes
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
|
|
||||||
|
|
@ -27,8 +31,11 @@ namespace gcpp {
|
||||||
|
|
||||||
void CopyMat(const MatPtr& from, MatPtr& to) {
|
void CopyMat(const MatPtr& from, MatPtr& to) {
|
||||||
PROFILER_FUNC;
|
PROFILER_FUNC;
|
||||||
|
HWY_ASSERT_M(from.HasPtr() && to.HasPtr(), to.Name());
|
||||||
HWY_ASSERT(to.Rows() == from.Rows() && to.Cols() == from.Cols());
|
HWY_ASSERT(to.Rows() == from.Rows() && to.Cols() == from.Cols());
|
||||||
HWY_ASSERT(to.GetType() == from.GetType());
|
HWY_ASSERT(to.GetType() == from.GetType());
|
||||||
|
to.SetScale(from.Scale());
|
||||||
|
|
||||||
if (to.IsPacked() && from.IsPacked()) {
|
if (to.IsPacked() && from.IsPacked()) {
|
||||||
HWY_ASSERT(to.PackedBytes() == from.PackedBytes());
|
HWY_ASSERT(to.PackedBytes() == from.PackedBytes());
|
||||||
hwy::CopyBytes(from.Packed(), to.Packed(), to.PackedBytes());
|
hwy::CopyBytes(from.Packed(), to.Packed(), to.PackedBytes());
|
||||||
|
|
@ -45,6 +52,8 @@ void CopyMat(const MatPtr& from, MatPtr& to) {
|
||||||
void ZeroInit(MatPtr& mat) {
|
void ZeroInit(MatPtr& mat) {
|
||||||
PROFILER_FUNC;
|
PROFILER_FUNC;
|
||||||
HWY_ASSERT_M(mat.HasPtr(), mat.Name());
|
HWY_ASSERT_M(mat.HasPtr(), mat.Name());
|
||||||
|
mat.SetScale(1.0f);
|
||||||
|
|
||||||
if (mat.IsPacked()) {
|
if (mat.IsPacked()) {
|
||||||
hwy::ZeroBytes(mat.Packed(), mat.PackedBytes());
|
hwy::ZeroBytes(mat.Packed(), mat.PackedBytes());
|
||||||
return;
|
return;
|
||||||
|
|
@ -55,6 +64,31 @@ void ZeroInit(MatPtr& mat) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void RandInit(MatPtr& mat, float stddev, std::mt19937& gen) {
|
||||||
|
PROFILER_FUNC;
|
||||||
|
HWY_ASSERT_M(mat.HasPtr(), mat.Name());
|
||||||
|
// Only generates float/double for use by backprop/.
|
||||||
|
HWY_ASSERT(mat.GetType() == Type::kF32 || mat.GetType() == Type::kF64);
|
||||||
|
mat.SetScale(1.0f);
|
||||||
|
|
||||||
|
std::normal_distribution<float> dist(0.0, stddev);
|
||||||
|
if (mat.GetType() == Type::kF32) {
|
||||||
|
for (size_t r = 0; r < mat.Rows(); ++r) {
|
||||||
|
float* HWY_RESTRICT row = mat.RowT<float>(r);
|
||||||
|
for (size_t c = 0; c < mat.Cols(); ++c) {
|
||||||
|
row[c] = dist(gen);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t r = 0; r < mat.Rows(); ++r) {
|
||||||
|
double* HWY_RESTRICT row = mat.RowT<double>(r);
|
||||||
|
for (size_t c = 0; c < mat.Cols(); ++c) {
|
||||||
|
row[c] = dist(gen);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Returns `num` rounded up to an odd number of cache lines. This would also
|
// Returns `num` rounded up to an odd number of cache lines. This would also
|
||||||
// prevent 4K aliasing and is coprime with the cache associativity, which
|
// prevent 4K aliasing and is coprime with the cache associativity, which
|
||||||
// might reduce conflict misses, but we instead use `StrideForCyclicOffsets`.
|
// might reduce conflict misses, but we instead use `StrideForCyclicOffsets`.
|
||||||
|
|
@ -84,6 +118,7 @@ static size_t Stride(const Allocator2& allocator, const MatPtr& mat,
|
||||||
}
|
}
|
||||||
|
|
||||||
void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) {
|
void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) {
|
||||||
|
if (mat.GetType() == Type::kNUQ) padding = MatPadding::kPacked;
|
||||||
const Allocator2& allocator = ThreadingContext2::Get().allocator;
|
const Allocator2& allocator = ThreadingContext2::Get().allocator;
|
||||||
const size_t stride = Stride(allocator, mat, padding);
|
const size_t stride = Stride(allocator, mat, padding);
|
||||||
const size_t num = mat.Rows() * stride;
|
const size_t num = mat.Rows() * stride;
|
||||||
|
|
@ -97,4 +132,16 @@ void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) {
|
||||||
storage_ = allocator.AllocBytes(padded_bytes);
|
storage_ = allocator.AllocBytes(padded_bytes);
|
||||||
mat.SetPtr(storage_.get(), stride);
|
mat.SetPtr(storage_.get(), stride);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void MatOwners::AllocateFor(const std::vector<MatPtr*>& mats,
|
||||||
|
MatPadding padding, hwy::ThreadPool& pool) {
|
||||||
|
const size_t start = owners_.size();
|
||||||
|
owners_.resize(start + mats.size());
|
||||||
|
|
||||||
|
// Allocate in parallel because faulting in large tensors is slow.
|
||||||
|
pool.Run(0, mats.size(), [&](uint64_t task, size_t /*thread*/) {
|
||||||
|
owners_[start + task].AllocateFor(*mats[task], padding);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
46
util/mat.h
46
util/mat.h
|
|
@ -22,6 +22,7 @@
|
||||||
|
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
// IWYU pragma: begin_exports
|
// IWYU pragma: begin_exports
|
||||||
#include "compression/fields.h"
|
#include "compression/fields.h"
|
||||||
|
|
@ -31,6 +32,7 @@
|
||||||
#include "util/basics.h" // Extents2D
|
#include "util/basics.h" // Extents2D
|
||||||
// IWYU pragma: end_exports
|
// IWYU pragma: end_exports
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
@ -71,7 +73,8 @@ class MatPtr : public IFields {
|
||||||
|
|
||||||
bool HasPtr() const { return ptr_ != nullptr; }
|
bool HasPtr() const { return ptr_ != nullptr; }
|
||||||
|
|
||||||
bool IsPacked() const { return stride_ == cols_; }
|
// A single row counts as packed because there is no padding between rows.
|
||||||
|
bool IsPacked() const { return (stride_ == cols_) || (rows_ == 1); }
|
||||||
|
|
||||||
const void* Packed() const {
|
const void* Packed() const {
|
||||||
HWY_DASSERT_M(IsPacked(), name_.c_str());
|
HWY_DASSERT_M(IsPacked(), name_.c_str());
|
||||||
|
|
@ -132,11 +135,10 @@ class MatPtr : public IFields {
|
||||||
float Scale() const { return scale_; }
|
float Scale() const { return scale_; }
|
||||||
void SetScale(float scale) { scale_ = scale; }
|
void SetScale(float scale) { scale_ = scale; }
|
||||||
|
|
||||||
// Name is a terse identifier. `MakeKey` in `blob_store.cc` requires that it
|
// A terse identifier unique across all tensors of the model.
|
||||||
// be <= 16 bytes including prefixes/suffixes. The initial name set by the
|
|
||||||
// ctor is for the tensor, but `ForEachTensor` in `weights.h` adds a per-layer
|
|
||||||
// suffix, and when loading, we call `SetName` with that.
|
|
||||||
const char* Name() const override { return name_.c_str(); }
|
const char* Name() const override { return name_.c_str(); }
|
||||||
|
// `MakeKey` in `blob_store.cc` requires that this be <= 16 bytes, including
|
||||||
|
// the `LayerSuffix` for per-layer tensors.
|
||||||
void SetName(const char* name) {
|
void SetName(const char* name) {
|
||||||
name_ = name;
|
name_ = name;
|
||||||
HWY_ASSERT_M(name_.size() <= sizeof(hwy::uint128_t), name);
|
HWY_ASSERT_M(name_.size() <= sizeof(hwy::uint128_t), name);
|
||||||
|
|
@ -194,11 +196,13 @@ class MatPtr : public IFields {
|
||||||
uint32_t stride_;
|
uint32_t stride_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Non-type erased version of `MatPtr`. Use this when operating on the values.
|
// Non-type erased version of `MatPtr`. Although `MatPtr` also provides
|
||||||
|
// type-aware accessors (`RowT`), this class is more convenient when accessing
|
||||||
|
// elements, and ensures the template argument and `Type` are consistent.
|
||||||
template <typename MatT>
|
template <typename MatT>
|
||||||
class MatPtrT : public MatPtr {
|
class MatPtrT : public MatPtr {
|
||||||
public:
|
public:
|
||||||
// Runtime-specified shape.
|
// Called by `MatStorageT`.
|
||||||
MatPtrT(const char* name, Extents2D extents)
|
MatPtrT(const char* name, Extents2D extents)
|
||||||
: MatPtr(name, TypeEnum<MatT>(), extents) {}
|
: MatPtr(name, TypeEnum<MatT>(), extents) {}
|
||||||
// Take shape from `TensorInfo` to avoid duplicating it in the caller.
|
// Take shape from `TensorInfo` to avoid duplicating it in the caller.
|
||||||
|
|
@ -247,6 +251,15 @@ class MatPtrT : public MatPtr {
|
||||||
HWY_ASSERT(IsPacked());
|
HWY_ASSERT(IsPacked());
|
||||||
return MakeSpan(Row(0), num_elements_);
|
return MakeSpan(Row(0), num_elements_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For when a span of a single row is required. This also works if padded,
|
||||||
|
// but does not support `GetType() == kNUQ`, because that requires the use of
|
||||||
|
// offsets instead of a row pointer. Used by `gemma-inl.h` to decompress
|
||||||
|
// embeddings.
|
||||||
|
PackedSpan<const MatT> RowSpan(size_t row) const {
|
||||||
|
HWY_DASSERT(GetType() != Type::kNUQ);
|
||||||
|
return MakeConstSpan(Row(row), Cols());
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT<T>`, plus the
|
// Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT<T>`, plus the
|
||||||
|
|
@ -340,6 +353,25 @@ class MatOwner {
|
||||||
AlignedPtr2<uint8_t[]> storage_;
|
AlignedPtr2<uint8_t[]> storage_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Multiple `MatOwner`, with support for parallel allocation.
|
||||||
|
class MatOwners {
|
||||||
|
public:
|
||||||
|
// Ignores `padding` for NUQ tensors, which are always packed.
|
||||||
|
void AllocateFor(MatPtr& mat, MatPadding padding) {
|
||||||
|
if (mat.GetType() == Type::kNUQ) padding = MatPadding::kPacked;
|
||||||
|
owners_.push_back(MatOwner());
|
||||||
|
owners_.back().AllocateFor(mat, padding);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocates multiple in parallel. Ignores `padding` for NUQ tensors,
|
||||||
|
// which are always packed.
|
||||||
|
void AllocateFor(const std::vector<MatPtr*>& mats, MatPadding padding,
|
||||||
|
hwy::ThreadPool& pool);
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<MatOwner> owners_;
|
||||||
|
};
|
||||||
|
|
||||||
// `MatStorageT` IS-A `MatPtrT` and HAS-A `MatOwner`. Used by `backprop/` and
|
// `MatStorageT` IS-A `MatPtrT` and HAS-A `MatOwner`. Used by `backprop/` and
|
||||||
// tests to allocate and access tensors of a known type. By contrast, the
|
// tests to allocate and access tensors of a known type. By contrast, the
|
||||||
// heterogeneous model weights are owned by vectors of `MatOwner`.
|
// heterogeneous model weights are owned by vectors of `MatOwner`.
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,9 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <mutex> // NOLINT
|
#include <mutex> // NOLINT
|
||||||
|
|
||||||
|
#include "hwy/base.h" // HWY_ASSERT, HWY_UNLIKELY
|
||||||
|
#include "hwy/profiler.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
static ThreadingArgs s_args;
|
static ThreadingArgs s_args;
|
||||||
|
|
@ -41,6 +44,7 @@ static std::mutex s_ctx_mutex;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*static*/ ThreadingContext2& ThreadingContext2::Get() {
|
/*static*/ ThreadingContext2& ThreadingContext2::Get() {
|
||||||
|
PROFILER_FUNC;
|
||||||
// We do not bother with double-checked locking because it requires an
|
// We do not bother with double-checked locking because it requires an
|
||||||
// atomic pointer, but we prefer to use unique_ptr for simplicity. Also,
|
// atomic pointer, but we prefer to use unique_ptr for simplicity. Also,
|
||||||
// callers can cache the result and call less often.
|
// callers can cache the result and call less often.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue