mirror of https://github.com/google/gemma.cpp.git
Infra improvements (2)
ops.h: move CreateInvTimescale to allow calling without depending on gemma.
Pass around MatMulEnv instead of pools to avoid re-creating the env.
profiler.h can now be used outside SIMD code.
allocator: add StepBytes and QuantumSteps.
Rename worker threads with package/cluster in the name.
threading: add Visit* to IndexRange.

PiperOrigin-RevId: 718766704
parent f37402da57
commit a60b564b88
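In brief, the calling convention changes as sketched below (caller-side illustration, not code from this commit; `tokenizer_path`, `weights_path`, `info`, `prompt`, `kv_cache`, `runtime_config`, and `timing_info` are assumed setup). The public Gemma API is unchanged, but internally a single MatMulEnv replaces the NestedPools& that was threaded through the Generate* functions:

// Hypothetical caller-side sketch of the new plumbing.
gcpp::NestedPools pools(/*max_threads=*/0, gcpp::Tristate::kDefault,
                        gcpp::BoundedSlice(0, /*max_packages=*/1));
gcpp::Allocator::Init(pools.Topology());
gcpp::Gemma gemma(tokenizer_path, weights_path, info, pools);  // creates env_ once
gemma.Generate(runtime_config, prompt, /*pos=*/0, /*prefix_end=*/0, kv_cache,
               timing_info);  // forwards &env_ internally instead of pools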
BUILD.bazel (11 changes)
@@ -86,6 +86,7 @@ cc_library(
     name = "ops",
     hdrs = [
         "ops/matmul.h",
+        "ops/ops.h",
     ],
     textual_hdrs = [
         "ops/dot-inl.h",
@@ -104,6 +105,7 @@ cc_library(
         "@highway//:hwy",
         "@highway//:math",
         "@highway//:matvec",
         "@highway//:nanobenchmark",  # timer
+        "@highway//:profiler",
         "@highway//:thread_pool",
     ],
@@ -145,7 +147,6 @@ cc_test(
     deps = [
         ":allocator",
         ":common",
-        ":gemma_lib",
         ":ops",
         ":test_util",
         "@googletest//:gtest_main",  # buildcleaner: keep
@@ -221,8 +222,11 @@ cc_test(
     timeout = "long",
     srcs = ["ops/bench_matmul.cc"],
     local_defines = ["HWY_IS_TEST"],
-    # for test_suite.
-    tags = ["ops_tests"],
+    tags = [
+        "manual",
+        "notap",
+        "ops_tests",  # for test_suite.
+    ],
     deps = [
         ":allocator",
         ":basics",
@@ -683,6 +687,7 @@ cc_test(
         ":basics",
         ":common",
         ":gemma_lib",
+        ":ops",
         ":optimizer",
         ":prompt",
         ":sampler",
@@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2b565e87d50b151660494624af532ac0b6076c79 EXCLUDE_FROM_ALL)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG f2209b911c74019e85d0b7a7a2833c9a2e1b7995 EXCLUDE_FROM_ALL)
 FetchContent_MakeAvailable(highway)

 ## Note: absl needs to be installed by sentencepiece. This will only happen if
@@ -94,6 +94,7 @@ set(SOURCES
   ops/matmul-inl.h
   ops/matvec-inl.h
   ops/ops-inl.h
+  ops/ops.h
   ops/sum-inl.h
   paligemma/image.cc
   paligemma/image.h
@@ -17,7 +17,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")

 # Require a more recent version.
 git_override(
     module_name = "highway",
-    commit = "2b565e87d50b151660494624af532ac0b6076c79",
+    commit = "f2209b911c74019e85d0b7a7a2833c9a2e1b7995",
     remote = "https://github.com/google/highway",
 )
@@ -31,8 +31,8 @@
 #include "backprop/prompt.h"
 #include "backprop/sampler.h"
 #include "backprop/test_util.h"
-#include "gemma/activations.h"
 #include "gemma/configs.h"
+#include "ops/ops.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -223,8 +223,9 @@ void TestEndToEnd() {
   ReverseSequenceSampler training_task({0, 0, 1, 1});
   std::vector<Prompt> batch = training_task.SampleBatch(3, gen);

-  RowVectorBatch<float> inv_timescale = Activations::CreateInvTimescale(
-      config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk);
+  RowVectorBatch<float> inv_timescale = CreateInvTimescale(
+      config.layer_configs[0].qkv_dim,
+      config.layer_configs[0].post_qk == PostQKType::HalfRope);
   for (const Prompt& prompt : batch) {
     ReverseSequenceSampler::LogPrompt(prompt);
     RandInit(weights.get(), 1.0f, gen);
@@ -28,11 +28,11 @@
 #include "backprop/prompt.h"
 #include "backprop/sampler.h"
 #include "compression/shared.h"
-#include "gemma/activations.h"
 #include "gemma/common.h"
 #include "gemma/configs.h"
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
+#include "ops/ops.h"
 #include "util/allocator.h"
 #include "util/basics.h"
 #include "util/threading.h"
@@ -62,8 +62,9 @@ TEST(OptimizeTest, GradientDescent) {
   ForwardPass<float> forward(config), backward(config);
   KVCache kv_cache = KVCache::Create(config, /*prefill_tbatch_size=*/16);

-  RowVectorBatch<float> inv_timescale = Activations::CreateInvTimescale(
-      config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk);
+  RowVectorBatch<float> inv_timescale = CreateInvTimescale(
+      config.layer_configs[0].qkv_dim,
+      config.layer_configs[0].post_qk == PostQKType::HalfRope);

   Gemma gemma(GemmaTokenizer(), info, pools);
@@ -303,6 +303,42 @@ struct CompressTraits<BF16> {
     }
   }

+#if 0
+  template <class DBF, HWY_IF_BF16_D(DBF)>
+  static HWY_INLINE void DecompressAndZeroPad(
+      DBF dbf, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
+      BF16* HWY_RESTRICT raw, size_t num) {
+    const size_t N16 = hn::Lanes(dbf);
+
+#if 1
+    const BF16* HWY_RESTRICT start = packed.ptr + packed_ofs;
+    using VBF = hn::Vec<decltype(dbf)>;
+    size_t i = 0;
+    if (num >= 4 * N16) {
+      for (; i <= num - 4 * N16; i += 4 * N16) {
+        const VBF packed0 = hn::LoadU(dbf, start + i + 0 * N16);
+        const VBF packed1 = hn::LoadU(dbf, start + i + 1 * N16);
+        const VBF packed2 = hn::LoadU(dbf, start + i + 2 * N16);
+        const VBF packed3 = hn::LoadU(dbf, start + i + 3 * N16);
+        hn::StoreU(packed0, dbf, raw + i + 0 * N16);
+        hn::StoreU(packed1, dbf, raw + i + 1 * N16);
+        hn::StoreU(packed2, dbf, raw + i + 2 * N16);
+        hn::StoreU(packed3, dbf, raw + i + 3 * N16);
+      }
+    }
+
+    for (; i < num; i += N16) {
+      const size_t remaining = num - i;
+      const VBF packed0 = hn::LoadN(dbf, start + i, remaining);
+      hn::StoreU(packed0, dbf, raw + i);
+    }
+#else
+    hwy::CopyBytes(packed.ptr + packed_ofs, raw, num * sizeof(BF16));
+    hwy::ZeroBytes(raw + num, (hwy::RoundUpTo(num, N16) - num) * sizeof(BF16));
+#endif
+  }
+#endif
+
   template <class DF, HWY_IF_F32_D(DF)>
   static HWY_INLINE void DecompressAndZeroPad(
       DF df, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
@@ -534,7 +570,7 @@ HWY_INLINE void Decompress2(DRaw d, const PackedSpan<Packed>& packed,
 // also wants to scale the decompressed elements.
 // `TRaw` can be `float/BF16`, or `double` if `Packed` is `float`.
 template <class DRaw, typename Packed, typename TRaw = hn::TFromD<DRaw>>
-HWY_NOINLINE void DecompressAndZeroPad(DRaw d, const PackedSpan<Packed>& packed,
+HWY_INLINE void DecompressAndZeroPad(DRaw d, const PackedSpan<Packed>& packed,
                                      const size_t packed_ofs, TRaw* raw,
                                      size_t num) {
   detail::VerifyRawAndPackedForDecompress<DRaw, Packed>();
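For orientation, a scalar model of the DecompressAndZeroPad contract (an illustration inferred from the code above, not the SIMD implementation): copy `num` elements starting at `packed_ofs`, then zero the pad up to the next multiple of the lane count so callers may always load whole vectors.

// Hypothetical scalar sketch; N stands for hn::Lanes(d).
void DecompressAndZeroPadScalar(const float* packed, size_t packed_ofs,
                                float* raw, size_t num, size_t N) {
  for (size_t i = 0; i < num; ++i) raw[i] = packed[packed_ofs + i];
  const size_t padded = (num + N - 1) / N * N;  // hwy::RoundUpTo(num, N)
  for (size_t i = num; i < padded; ++i) raw[i] = 0.0f;
}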
@@ -286,7 +286,7 @@ struct PackedSpan {
   }

   Packed* HWY_RESTRICT ptr;
-  size_t num;  // for BoundsCheck and nuq-inl.h HWY_ASSERT.
+  size_t num;  // for BoundsCheck, also required by nuq-inl.h.
 };

 // Avoids spelling out the template parameter in every call.
@@ -17,7 +17,7 @@ project(hello_world)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)

 include(FetchContent)
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2b565e87d50b151660494624af532ac0b6076c79)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG f2209b911c74019e85d0b7a7a2833c9a2e1b7995)
 FetchContent_MakeAvailable(highway)
 FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
 FetchContent_MakeAvailable(sentencepiece)
@@ -18,12 +18,10 @@

 #include <stddef.h>

-#include <cmath>
-#include <memory>  // std::unique_ptr
-
 #include "compression/shared.h"  // BF16
 #include "gemma/configs.h"
 #include "ops/matmul.h"  // MatMulEnv
+#include "ops/ops.h"  // CreateInvTimescale
 #include "util/allocator.h"  // RowVectorBatch
 #include "util/threading.h"
 #include "hwy/base.h"  // HWY_DASSERT
@@ -65,7 +63,7 @@ struct Activations {
   RowVectorBatch<float> inv_timescale;

   // Dynamic because no default ctor and only initialized in `Allocate`.
-  std::unique_ptr<MatMulEnv> env;
+  MatMulEnv* env;

   PostQKType post_qk = PostQKType::Rope;
   // And the config.
@@ -74,23 +72,7 @@ struct Activations {
   size_t seq_len;
   size_t cache_pos_size = 0;

-  static RowVectorBatch<float> CreateInvTimescale(
-      size_t qkv_dim, PostQKType post_qk, double base_frequency = 10000.0) {
-    const size_t rope_dim =
-        post_qk == PostQKType::HalfRope ? qkv_dim / 2 : qkv_dim;
-    RowVectorBatch<float> inv_timescale(Extents2D(1, rope_dim / 2));
-    for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
-      const double freq_exponents =
-          static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
-      // Replacing with expf(ln(1E4) * freq_exponents) changes results
-      // noticeably.
-      inv_timescale.Batch(0)[dim] =
-          static_cast<float>(1.0 / std::pow(base_frequency, freq_exponents));
-    }
-    return inv_timescale;
-  }
-
-  void Allocate(size_t batch_size, NestedPools& pools) {
+  void Allocate(size_t batch_size, MatMulEnv* env) {
     post_qk = layer_config.post_qk;
     const size_t model_dim = weights_config.model_dim;
     const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
@@ -124,9 +106,10 @@ struct Activations {
           RowVectorBatch<float>(Extents2D(batch_size, model_dim));
     }

-    inv_timescale = CreateInvTimescale(qkv_dim, post_qk);
+    inv_timescale = CreateInvTimescale(layer_config.qkv_dim,
+                                       post_qk == PostQKType::HalfRope);

-    env = std::make_unique<MatMulEnv>(pools);
+    this->env = env;
   }
 };
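The upshot of the activations changes, as a minimal usage sketch (assuming an existing `pools` and `model`): Activations no longer creates its own MatMulEnv on every Allocate; it borrows one that the caller keeps alive.

gcpp::MatMulEnv env(pools);  // owned by the caller; Gemma holds one as env_
gcpp::Activations activations(model.Config());
activations.Allocate(/*batch_size=*/1, &env);  // stores only the raw pointer
// `env` must outlive `activations`; nothing is re-created per call.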
@@ -37,6 +37,7 @@
 #include "hwy/base.h"
 #include "hwy/bit_set.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/profiler.h"
 #include "hwy/timer.h"

 // Include guard (still compiled once per target)
@@ -53,7 +54,6 @@
 #include "ops/matmul-inl.h"
 #include "ops/matvec-inl.h"
 #include "ops/ops-inl.h"
-#include "hwy/profiler.h"  // also uses SIMD

 #ifndef GEMMA_TYPE
 #if HWY_IDE
@@ -81,7 +81,7 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
                                    const KVCaches& kv_caches) {
   PROFILER_ZONE("Gen.Griffin");
   KVCache& kv_cache = kv_caches[0];
-  hwy::ThreadPool& pool = activations.env->Pool();
+  hwy::ThreadPool& pool = activations.env->parallel.Pools().Pool(0);
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
   const size_t model_dim = layer_weights->layer_config.model_dim;
@@ -517,7 +517,7 @@ class GemmaAttention {
         layer_weights_(*layer_weights),
         div_seq_len_(div_seq_len),
         kv_caches_(kv_caches),
-        pool_(activations.env->Pool()) {
+        pool_(activations.env->parallel.Pools().Pool(0)) {
     HWY_DASSERT(num_queries_ <= kv_caches_.size());
     HWY_DASSERT_M((layer_config_.heads % layer_config_.kv_heads) == 0,
                   "query heads must be a multiple of key-value heads");
@@ -654,7 +654,7 @@ class VitAttention {
         activations_(activations),
         layer_weights_(*layer_weights),
         layer_config_(layer_weights->layer_config),
-        pool_(activations.env->Pool()) {}
+        pool_(activations.env->parallel.Pools().Pool(0)) {}

   HWY_INLINE void operator()() {
     ComputeQKV();
@@ -1072,10 +1072,10 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
   // which is not the case here. We should relax that requirement on MatMul and
   // then use the above. For now, we rely on MatVecAdd instead.
   for (size_t i = 0; i < seq_len; ++i) {
-    MatVecAdd(weights.vit_img_embedding_kernel, 0, model_dim, patch_size,
-              image_patches[i].get(),
-              weights.vit_img_embedding_bias.data_scale1(),
-              activations.x.Batch(i), activations.env->Pool());
+    MatVecAdd(
+        weights.vit_img_embedding_kernel, 0, model_dim, patch_size,
+        image_patches[i].get(), weights.vit_img_embedding_bias.data_scale1(),
+        activations.x.Batch(i), activations.env->parallel.Pools().Pool(0));
   }
   // Add position embeddings.
   AddFrom(weights.vit_img_pos_embedding.data_scale1(), activations.x.All(),
@@ -1283,7 +1283,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   Activations prefill_activations(weights.weights_config);
   if (use_prefill_activations) {
     prefill_activations.Allocate(runtime_config.prefill_tbatch_size,
-                                 activations.env->Pools());
+                                 activations.env);
   }
   Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end,
           query_idx_start, weights,
@@ -1354,14 +1354,14 @@ template <typename T>
 void GenerateSingleT(const ModelWeightsStorage& model,
                      const RuntimeConfig& runtime_config,
                      const PromptTokens& prompt, size_t pos, size_t prefix_end,
-                     KVCache& kv_cache, NestedPools& pools,
+                     KVCache& kv_cache, MatMulEnv* env,
                      TimingInfo& timing_info) {
   constexpr size_t kNumQueries = 1;
   const size_t qbatch_start = 0;

   // TODO: move into Gemma?
   Activations activations(model.Config());
-  activations.Allocate(kNumQueries, pools);
+  activations.Allocate(kNumQueries, env);

   const QueriesPromptTokens queries_prompt(&prompt, kNumQueries);
   QueriesPos queries_pos(&pos, kNumQueries);
@@ -1378,7 +1378,7 @@ void GenerateBatchT(const ModelWeightsStorage& model,
                     const QueriesPromptTokens& queries_prompt,
                     const QueriesPos& queries_pos,
                     const QueriesPos& queries_prefix_end,
-                    const KVCaches& kv_caches, NestedPools& pools,
+                    const KVCaches& kv_caches, MatMulEnv* env,
                     TimingInfo& timing_info) {
   const size_t num_queries = queries_prompt.size();
   HWY_ASSERT(queries_pos.size() == num_queries);
@@ -1393,7 +1393,7 @@ void GenerateBatchT(const ModelWeightsStorage& model,
   }

   Activations activations(model.Config());
-  activations.Allocate(max_qbatch_size, pools);
+  activations.Allocate(max_qbatch_size, env);

   for (size_t qbatch_start = 0; qbatch_start < num_queries;
        qbatch_start += max_qbatch_size) {
@@ -1415,7 +1415,7 @@ template <typename T>
 void GenerateImageTokensT(const ModelWeightsStorage& model,
                           const RuntimeConfig& runtime_config,
                           const Image& image, ImageTokens& image_tokens,
-                          NestedPools& pools) {
+                          MatMulEnv* env) {
   if (model.Config().vit_config.layer_configs.empty()) {
     HWY_ABORT("Model does not support generating image tokens.");
   }
@@ -1423,7 +1423,7 @@ void GenerateImageTokensT(const ModelWeightsStorage& model,
   ModelConfig vit_config = GetVitConfig(model.Config());
   prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len;
   Activations prefill_activations(vit_config);
-  prefill_activations.Allocate(vit_config.seq_len, pools);
+  prefill_activations.Allocate(vit_config.seq_len, env);
   // Weights are for the full PaliGemma model, not just the ViT part.
   PrefillVit(*model.GetWeightsOfType<T>(), prefill_runtime_config, image,
              image_tokens, prefill_activations);
@@ -1438,11 +1438,10 @@ void GenerateImageTokensT(const ModelWeightsStorage& model,
 void GenerateSingle(  // NOLINT(misc-definitions-in-headers)
     GEMMA_TYPE, const ModelWeightsStorage& model,
     const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos,
-    size_t prefix_end, KVCache& kv_cache, NestedPools& pools,
+    size_t prefix_end, KVCache& kv_cache, MatMulEnv* env,
     TimingInfo& timing_info) {
   HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_TYPE>)
-  (model, runtime_config, prompt, pos, prefix_end, kv_cache, pools,
-   timing_info);
+  (model, runtime_config, prompt, pos, prefix_end, kv_cache, env, timing_info);
 }

 void GenerateBatch(  // NOLINT(misc-definitions-in-headers)
@@ -1450,18 +1449,18 @@ void GenerateBatch(  // NOLINT(misc-definitions-in-headers)
     const RuntimeConfig& runtime_config,
     const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos,
     const QueriesPos& queries_prefix_end, const KVCaches& kv_caches,
-    NestedPools& pools, TimingInfo& timing_info) {
+    MatMulEnv* env, TimingInfo& timing_info) {
   HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_TYPE>)
   (model, runtime_config, queries_prompt, queries_pos, queries_prefix_end,
-   kv_caches, pools, timing_info);
+   kv_caches, env, timing_info);
 }

 void GenerateImageTokens(  // NOLINT(misc-definitions-in-headers)
     GEMMA_TYPE, const ModelWeightsStorage& model,
     const RuntimeConfig& runtime_config, const Image& image,
-    ImageTokens& image_tokens, NestedPools& pools) {
+    ImageTokens& image_tokens, MatMulEnv* env) {
   HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateImageTokensT<GEMMA_TYPE>)
-  (model, runtime_config, image, image_tokens, pools);
+  (model, runtime_config, image, image_tokens, env);
 }

 #endif  // HWY_ONCE
@@ -41,23 +41,24 @@ namespace gcpp {

 Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
              const ModelInfo& info, NestedPools& pools)
-    : pools_(pools), tokenizer_(tokenizer_path) {
-  model_.Load(weights, info.model, info.weight, info.wrapping, pools_.Pool(),
+    : env_(pools), tokenizer_(tokenizer_path) {
+  model_.Load(weights, info.model, info.weight, info.wrapping,
+              env_.parallel.Pools().Pool(0),
               /*tokenizer_proto=*/nullptr);
 }

-Gemma::Gemma(const Path& weights, NestedPools& pools) : pools_(pools) {
+Gemma::Gemma(const Path& weights, NestedPools& pools) : env_(pools) {
   std::string tokenizer_proto;
   model_.Load(weights, Model::UNKNOWN, Type::kUnknown, PromptWrapping::GEMMA_IT,
-              pools_.Pool(), &tokenizer_proto);
+              env_.parallel.Pools().Pool(0), &tokenizer_proto);
   tokenizer_.Deserialize(tokenizer_proto);
 }

 Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
              NestedPools& pools)
-    : pools_(pools), tokenizer_(std::move(tokenizer)) {
+    : env_(pools), tokenizer_(std::move(tokenizer)) {
   HWY_ASSERT(info.weight == Type::kF32);
-  model_.Allocate(info.model, info.weight, pools_.Pool());
+  model_.Allocate(info.model, info.weight, env_.parallel.Pools().Pool(0));
 }

 Gemma::~Gemma() {
@@ -72,16 +73,16 @@ Gemma::~Gemma() {
       const RuntimeConfig& runtime_config,                                    \
       const PromptTokens& prompt, size_t pos,                                 \
       size_t prefix_end, KVCache& kv_cache,                                   \
-      NestedPools& pools, TimingInfo& timing_info);                           \
+      MatMulEnv* env, TimingInfo& timing_info);                               \
   extern void GenerateBatch(                                                  \
       TWEIGHT, const ModelWeightsStorage& model,                              \
       const RuntimeConfig& runtime_config, const QueriesPromptTokens& prompts,\
       const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end,    \
-      const KVCaches& kv_caches, NestedPools& pools, TimingInfo& timing_info);\
-  extern void GenerateImageTokens(                                            \
-      TWEIGHT, const ModelWeightsStorage& model,                              \
-      const RuntimeConfig& runtime_config, const Image& image,                \
-      ImageTokens& image_tokens, NestedPools& pools);
+      const KVCaches& kv_caches, MatMulEnv* env, TimingInfo& timing_info);    \
+  extern void GenerateImageTokens(TWEIGHT, const ModelWeightsStorage& model,  \
+                                  const RuntimeConfig& runtime_config,        \
+                                  const Image& image,                         \
+                                  ImageTokens& image_tokens, MatMulEnv* env);
 GEMMA_DECLARE(float)
 GEMMA_DECLARE(BF16)
 GEMMA_DECLARE(NuqStream)
@@ -93,10 +94,10 @@ struct GenerateSingleT {
   void operator()(const ModelWeightsStorage& model,
                   const RuntimeConfig& runtime_config,
                   const PromptTokens& prompt, size_t pos, size_t prefix_end,
-                  KVCache& kv_cache, NestedPools& pools,
+                  KVCache& kv_cache, MatMulEnv* env,
                   TimingInfo& timing_info) const {
     GenerateSingle(TConfig(), model, runtime_config, prompt, pos, prefix_end,
-                   kv_cache, pools, timing_info);
+                   kv_cache, env, timing_info);
   }
 };
@@ -107,10 +108,10 @@ struct GenerateBatchT {
                   const QueriesPromptTokens& queries_prompt,
                   const QueriesPos& queries_pos,
                   const QueriesPos& queries_prefix_end,
-                  const KVCaches& kv_caches, NestedPools& pools,
+                  const KVCaches& kv_caches, MatMulEnv* env,
                   TimingInfo& timing_info) const {
     GenerateBatch(TConfig(), model, runtime_config, queries_prompt, queries_pos,
-                  queries_prefix_end, kv_caches, pools, timing_info);
+                  queries_prefix_end, kv_caches, env, timing_info);
   }
 };
@@ -118,21 +119,21 @@ template <class TConfig>
 struct GenerateImageTokensT {
   void operator()(const ModelWeightsStorage& model,
                   const RuntimeConfig& runtime_config, const Image& image,
-                  ImageTokens& image_tokens, NestedPools& pools) const {
+                  ImageTokens& image_tokens, MatMulEnv* env) const {
     GenerateImageTokens(TConfig(), model, runtime_config, image, image_tokens,
-                        pools);
+                        env);
   }
 };

 void Gemma::Generate(const RuntimeConfig& runtime_config,
                      const PromptTokens& prompt, size_t pos, size_t prefix_end,
                      KVCache& kv_cache, TimingInfo& timing_info) {
-  pools_.MaybeStartSpinning(runtime_config.use_spinning);
+  env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning);

   model_.CallForModelWeight<GenerateSingleT>(
-      runtime_config, prompt, pos, prefix_end, kv_cache, pools_, timing_info);
+      runtime_config, prompt, pos, prefix_end, kv_cache, &env_, timing_info);

-  pools_.MaybeStopSpinning(runtime_config.use_spinning);
+  env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning);
 }

 void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
@@ -149,23 +150,23 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
         QueriesPos(prefix_end_vec.data(), prefix_end_vec.size());
   }

-  pools_.MaybeStartSpinning(runtime_config.use_spinning);
+  env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning);

   model_.CallForModelWeight<GenerateBatchT>(
       runtime_config, queries_prompt, queries_pos, mutable_queries_prefix_end,
-      kv_caches, pools_, timing_info);
+      kv_caches, &env_, timing_info);

-  pools_.MaybeStopSpinning(runtime_config.use_spinning);
+  env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning);
 }

 void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
                                 const Image& image, ImageTokens& image_tokens) {
-  pools_.MaybeStartSpinning(runtime_config.use_spinning);
+  env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning);

   model_.CallForModelWeight<GenerateImageTokensT>(runtime_config, image,
-                                                  image_tokens, pools_);
+                                                  image_tokens, &env_);

-  pools_.MaybeStopSpinning(runtime_config.use_spinning);
+  env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning);
 }

 // Non-template functions moved from gemma-inl.h to avoid ODR violations.
@@ -29,6 +29,7 @@
 #include "gemma/kv_cache.h"
 #include "gemma/tokenizer.h"
 #include "gemma/weights.h"
+#include "ops/matmul.h"  // MatMulEnv
 #include "paligemma/image.h"
 #include "util/allocator.h"  // RowVectorBatch
 #include "util/basics.h"  // TokenAndProb
@@ -246,7 +247,7 @@ class Gemma {
                            const Image& image, ImageTokens& image_tokens);

  private:
-  NestedPools& pools_;
+  MatMulEnv env_;

   GemmaTokenizer tokenizer_;
   // Type-erased so that this can be defined in the header.
@@ -42,6 +42,7 @@
 #include "util/threading.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/nanobenchmark.h"
+#include "hwy/profiler.h"
 #include "hwy/timer.h"

 // clang-format off
@@ -53,7 +54,6 @@
 // After highway.h
 #include "compression/compress-inl.h"
 #include "ops/matmul-inl.h"
-#include "hwy/profiler.h"  // also uses SIMD
 #include "hwy/tests/test_util-inl.h"

 HWY_BEFORE_NAMESPACE();
@@ -119,24 +119,24 @@ MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
 void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
                 std::vector<double>& times) {
   std::sort(times.begin(), times.end());
-  // Many measurements are with suboptimal configs, so report the best like
-  // bench_dnn, but also the ratio to the 3rd best.
-  const double elapsed = times[0];
-  const double ratio = times[2] / HWY_MAX(elapsed, 1E-6);
+  // bench_dnn reports the best and average, but the median seems more
+  // consistent and resistant to outliers.
+  const double elapsed = times[times.size() / 2];
+  const double ratio = elapsed / (times[0] + 1E-6);  // vs best, avoid / 0

   const size_t num_b = B_extents.Area();
-  // 2x because of FMA.
-  fprintf(stderr, "%.1f\t%.2f\n", 2 * 1E-9 * A_extents.rows * num_b / elapsed,
-          ratio);
+  // FMA counts as two FLOP.
+  fprintf(stderr, "%.1f\t(med %.3f ms = %0.2fx min)\n",
+          2 * 1E-9 * A_extents.rows * num_b / elapsed, elapsed * 1E3, ratio);
 }

 // Generates inputs and prints observed throughput of MatMul.
 // M = A rows, K = A cols, N = C cols.
 template <typename MatTA, typename MatTB = MatTA>
 void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
-  hwy::ThreadPool& pool = env.Pool();
-  fprintf(stderr, "BenchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n", M,
-          K, N, add, TypeName<MatTA>(), TypeName<MatTB>());
+  hwy::ThreadPool& pool = env.parallel.Pools().Pool(0);
+  fprintf(stderr, "\nBenchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n",
+          M, K, N, add, TypeName<MatTA>(), TypeName<MatTB>());

   const Extents2D A_extents(M, K);
   const Extents2D B_extents(N, K);  // already transposed
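Restated as a standalone sketch (illustration only, mirroring the new PrintSpeed): sort the samples, report GFLOP/s at the median, and show the ratio of median to best; the factor 2*M*K*N counts each multiply-add as two FLOP.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

void ReportMedian(size_t M, size_t K, size_t N, std::vector<double> times) {
  std::sort(times.begin(), times.end());
  const double med = times[times.size() / 2];    // resistant to outliers
  const double ratio = med / (times[0] + 1E-6);  // vs best; avoid division by 0
  std::fprintf(stderr, "%.1f\t(med %.3f ms = %0.2fx min)\n",
               2 * 1E-9 * M * K * N / med, med * 1E3, ratio);
}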
@@ -161,17 +161,26 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
   const float* add_row = add ? add_storage->data_scale1() : nullptr;
   const RowPtrF C = RowPtrFromBatch(c_batch);

+  constexpr size_t kSamples = 20;
   std::vector<double> times;
-  times.reserve(20);
-  double result = 0.0;
-  for (;;) {
+  times.reserve(kSamples);
+
+  Tristate use_spinning = Tristate::kDefault;
+  env.parallel.Pools().MaybeStartSpinning(use_spinning);
+
+  double keep = 0.0;
+  // Until enough samples collected *after* autotuning finished:
+  while (times.size() < kSamples) {
     const double t0 = hwy::platform::Now();
     MatMul(A, B, add_row, env, C);
-    times.push_back(hwy::platform::Now() - t0);
-    result += C.Row(0)[hwy::Unpredictable1()];
-    if (times.size() >= 20) break;
+    const double t1 = hwy::platform::Now();
+    double elapsed = t1 - t0;
+    keep += C.Row(0)[hwy::Unpredictable1()];
+
+    times.push_back(elapsed);
   }
-  hwy::PreventElision(result);
+  hwy::PreventElision(keep);
+  env.parallel.Pools().MaybeStopSpinning(use_spinning);
   PrintSpeed(A_extents, B_extents, times);
 }
@@ -182,7 +191,7 @@ void BenchAllMatMul() {
   if (first_target == 0) first_target = HWY_TARGET;
   if (HWY_TARGET != first_target) return;

-  for (size_t max_packages : {1, 2}) {
+  for (size_t max_packages : {/*1,*/ 2}) {
     const size_t max_threads = 0;  // no limit
     NestedPools pools(max_threads, Tristate::kDefault,
                       BoundedSlice(0, max_packages));
@@ -195,17 +204,14 @@ void BenchAllMatMul() {
     fprintf(stderr, "BenchAllMatMul %zu: %s %s\n", max_packages,
             pools.TopologyString(), pools.PinString());

-    Tristate use_spinning = Tristate::kDefault;
-    pools.MaybeStartSpinning(use_spinning);
     Allocator::Init(pools.Topology());
     MatMulEnv env(pools);

-    for (size_t batch_size : {1, /* 4, 128,*/ 512}) {
+    for (size_t batch_size : {1, 4, 128, 512}) {
       constexpr bool kAdd = false;
       BenchMatMul<BF16, SFP>(batch_size, 24576, 3072, kAdd, env);
       BenchMatMul<BF16, SFP>(batch_size, 3072, 24576, kAdd, env);
     }
-    pools.MaybeStopSpinning(use_spinning);
   }

   PROFILER_PRINT_RESULTS();
@@ -17,6 +17,7 @@
 #include "compression/compress.h"
 #include "hwy/base.h"
+#include "hwy/profiler.h"

 // Include guard for (potentially) SIMD code.
 #if defined(THIRD_PARTY_GEMMA_CPP_DOT_TOGGLE) == defined(HWY_TARGET_TOGGLE)
@@ -31,7 +32,6 @@
 #include "compression/compress-inl.h"
 #include "ops/fp_arith-inl.h"
 #include "hwy/contrib/math/math-inl.h"
-#include "hwy/profiler.h"  // also uses SIMD

 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
@@ -32,6 +32,7 @@
 #include "util/threading.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/profiler.h"
 #include "hwy/stats.h"
 #include "hwy/timer.h"
@@ -44,7 +45,6 @@
 // After highway.h
 #include "compression/test_util-inl.h"
 #include "ops/dot-inl.h"
-#include "hwy/profiler.h"  // also uses SIMD

 HWY_BEFORE_NAMESPACE();
 namespace gcpp {

ops/matmul.h (24 changes)
@@ -65,22 +65,36 @@ struct CacheSizes {
   size_t l3_bytes;
 };

+class MMParallel {
+ public:
+  MMParallel() : pools_(nullptr) {}
+  explicit MMParallel(NestedPools& pools) : pools_(&pools) {}
+
+  NestedPools& Pools() const { return *pools_; }
+  hwy::ThreadPool& Pool() const { return pools_->Pool(); }
+
+ private:
+  NestedPools* pools_;
+};
+
 // Allocations and threads, shared across MatMul calls.
 class MatMulEnv {
  public:
-  MatMulEnv() : pools_(nullptr) {}
-  explicit MatMulEnv(NestedPools& pools) : pools_(&pools) {
+  explicit MatMulEnv(NestedPools& pools) : parallel(pools) {
     const size_t N = hwy::VectorBytes() / sizeof(float);
     buf_ = RowVectorBatch<float>(Extents2D(pools.MaxWorkers(), 16 * N));
   }

   RowVectorBatch<float>& Buf() { return buf_; }
-  NestedPools& Pools() const { return *pools_; }
-  hwy::ThreadPool& Pool() const { return pools_->Pool(); }
+
+  MMParallel parallel;
+
+  // TODO: remove once no longer used.
+  NestedPools& Pools() const { return parallel.Pools(); }
+  hwy::ThreadPool& Pool() const { return parallel.Pool(); }

  private:
   RowVectorBatch<float> buf_;
-  NestedPools* pools_;
 };

 // Used for the A and B arguments of `MatMul`, which are always const.
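During the transition, both spellings below reach a pool; the forwarding accessors on MatMulEnv are explicitly marked for removal, so new call sites (as elsewhere in this commit) use the MMParallel member. Sketch assuming an initialized `pools`; that Pool() and Pools().Pool(0) name the same pool is an assumption based on the substitutions in this diff:

gcpp::MatMulEnv env(pools);
hwy::ThreadPool& legacy = env.Pool();                    // forwards via parallel
hwy::ThreadPool& direct = env.parallel.Pools().Pool(0);  // explicit new path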
@@ -137,8 +137,8 @@ float MaxAbs(const float* HWY_RESTRICT a, const Extents2D& extents) {
 }

 // B is already transposed.
-template <typename MatTA, typename MatTB>
-void AssertClose(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
+template <typename TA, typename TB>
+void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
                  const RowPtrF& C_slow, const RowPtrF& C) {
   const hn::ScalableTag<float> df;
   const size_t num_a = A.extents.Area();
@@ -160,13 +160,13 @@ void AssertClose(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
       MaxAbs(a.get(), A.Extents()) * MaxAbs(b_trans.get(), B.Extents());
   const double eps_bf16 = hwy::ConvertScalarTo<double>(hwy::Epsilon<BF16>());
   const double eps_f32 = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
-  double tolerance = 8 * norm * eps_f32;
+  double tolerance = 10 * norm * eps_f32;
   // Dot() also rounds F32,BF16 to BF16, but not with F32,F32, so increase the
   // tolerance there.
-  if (IsF32<MatTA>() && !IsF32<MatTB>()) {
+  if (IsF32<TA>() && IsF32<TB>()) {
     tolerance += 4 * max_abs * eps_bf16;
   }
-  if (tolerance > 8.0) {
+  if (tolerance > 500.0) {
     HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs);
   }
@@ -189,23 +189,22 @@ void AssertClose(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
 }

 // B is already transposed.
-template <typename MatTA, typename MatTB>
-HWY_INLINE void MatMulSlow(const ConstMat<MatTA> A, const ConstMat<MatTB> B,
+template <typename TA, typename TB>
+HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
                            const float* HWY_RESTRICT add_row, MatMulEnv& env,
                            const RowPtrF& C) {
-  // MatTA can be any Packed except NuqStream because it uses pointer
+  // TA can be any Packed except NuqStream because it uses pointer
   // arithmetic, because it is the second argument to Dot, which does not
   // support a v_ofs.
-  static_assert(sizeof(MatTA) >= sizeof(BF16), "A matrix must be BF16/f32");
+  static_assert(sizeof(TA) >= sizeof(BF16), "A matrix must be BF16/f32");
   const float scale = A.scale * B.scale;

   const hn::ScalableTag<float> df;  // lane type is ignored
-  const PackedSpan<const MatTB> b_span =
-      MakeSpan(B.ptr, B.ofs + B.extents.Area());
+  const PackedSpan<const TB> b_span = MakeSpan(B.ptr, B.ofs + B.extents.Area());
   const IndexRange all_rows_c(0, A.Extents().rows);
   const IndexRange all_cols_c(0, C.Cols());

-  NestedPools& pools = env.Pools();
+  NestedPools& pools = env.parallel.Pools();
   hwy::ThreadPool& all_packages = pools.AllPackages();
   const IndexRangePartition get_row_c =
       StaticPartition(all_rows_c, all_packages.NumWorkers(), 1);
@@ -213,7 +212,7 @@ HWY_INLINE void MatMulSlow(const ConstMat<MatTA> A, const ConstMat<MatTB> B,
       get_row_c, all_packages,
       [&](const IndexRange& rows_c, size_t package_idx) HWY_ATTR {
         hwy::ThreadPool& all_clusters = pools.AllClusters(package_idx);
-        const size_t multiple = Allocator::QuantumBytes() / sizeof(MatTB);
+        const size_t multiple = Allocator::QuantumBytes() / sizeof(TB);
         const IndexRangePartition get_col_c =
             StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple);
         ParallelizeOneRange(
@@ -240,20 +239,19 @@ void PrintSpeed(const char* algo, const Extents2D& A_extents,
           elapsed, 2 * 1E-9 * A_extents.rows * num_b / elapsed);
 }

-template <typename MatTA, typename MatTB = MatTA>
+template <typename TA, typename TB = TA>
 void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
                 MatMulEnv& env) {
-  hwy::ThreadPool& pool = env.Pool();
-  fprintf(stderr, "TestMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n",
-          rows_ac, cols_a_rows_b, cols_bc, add, TypeName<MatTA>(),
-          TypeName<MatTB>());
+  hwy::ThreadPool& pool = env.parallel.Pools().Pool();
+  fprintf(stderr, "TestMatMul %zu, %zu, %zu, add=%d, TA=%s, TB=%s\n", rows_ac,
+          cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>());

   const Extents2D A_extents(rows_ac, cols_a_rows_b);
   const Extents2D B_extents(cols_bc, cols_a_rows_b);  // already transposed
   const Extents2D C_extents(rows_ac, cols_bc);

-  MatStoragePtr<MatTA> a = GenerateMat<MatTA>(A_extents, pool);
-  MatStoragePtr<MatTB> b_trans = GenerateTransposedMat<MatTB>(B_extents, pool);
+  MatStoragePtr<TA> a = GenerateMat<TA>(A_extents, pool);
+  MatStoragePtr<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
   RowVectorBatch<float> c_slow_batch(C_extents);
   RowVectorBatch<float> c_batch(C_extents);
   HWY_ASSERT(a && b_trans);
@@ -303,8 +301,13 @@ void TestTiny() {
     Allocator::Init(pools.Topology());
     MatMulEnv env(pools);

-    for (size_t batch_size = 1; batch_size <= 3 * kRegRows; ++batch_size) {
-      TestMatMul<F32, F32>(batch_size, 256, 256, /*add=*/false, env);
+    for (size_t M = 1; M <= 3 * kRegRows; ++M) {
+      for (size_t K = 64; K <= 128; K *= 2) {
+        for (size_t N = /*kRegRows*/ 16; N <= 64;
+             N += max_packages * kRegRows) {
+          TestMatMul<F32, F32>(M, K, N, /*add=*/false, env);
+        }
+      }
     }
     pools.MaybeStopSpinning(use_spinning);
   }
@@ -23,6 +23,7 @@
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/profiler.h"

 #endif  // THIRD_PARTY_GEMMA_CPP_OPS_MATVEC_INL_H_
@@ -38,7 +39,6 @@
 #include "ops/dot-inl.h"
 #include "hwy/contrib/math/math-inl.h"
 #include "hwy/contrib/matvec/matvec-inl.h"
-#include "hwy/profiler.h"  // also uses SIMD

 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
@@ -32,6 +32,7 @@
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/detect_targets.h"
+#include "hwy/profiler.h"
 #endif  // THIRD_PARTY_GEMMA_CPP_OPS_OPS_INL_H_

 // Include guard for (potentially) SIMD code.
@@ -47,7 +48,6 @@
 #include "ops/sum-inl.h"
 #include "hwy/contrib/algo/transform-inl.h"
 #include "hwy/contrib/math/math-inl.h"
-#include "hwy/profiler.h"  // also uses SIMD

 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
@@ -0,0 +1,45 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_OPS_OPS_H_
+#define THIRD_PARTY_GEMMA_CPP_OPS_OPS_H_
+
+#include <stddef.h>
+
+#include <cmath>
+
+#include "util/allocator.h"
+#include "hwy/base.h"
+
+namespace gcpp {
+
+static inline HWY_MAYBE_UNUSED RowVectorBatch<float> CreateInvTimescale(
+    size_t qkv_dim, bool half_rope, double base_frequency = 10000.0) {
+  const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim;
+  RowVectorBatch<float> inv_timescale(Extents2D(1, rope_dim / 2));
+  for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
+    const double freq_exponents =
+        static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
+    // Replacing with expf(ln(1E4) * freq_exponents) changes results
+    // noticeably.
+    inv_timescale.Batch(0)[dim] =
+        static_cast<float>(1.0 / std::pow(base_frequency, freq_exponents));
+  }
+  return inv_timescale;
+}
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_OPS_OPS_H_
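These are the standard RoPE inverse frequencies: entry d holds 1 / base^(2d/rope_dim). A quick numeric check of the new free function (illustrative):

// qkv_dim = 8, half_rope = false -> rope_dim = 8 and four entries:
// d=0: 1.0, d=1: 1/10000^0.25 = 0.1, d=2: 0.01, d=3: 0.001.
gcpp::RowVectorBatch<float> inv =
    gcpp::CreateInvTimescale(/*qkv_dim=*/8, /*half_rope=*/false);
// With half_rope = true, only the first half of each head is rotated:
// rope_dim = qkv_dim / 2, so the row has two entries instead of four.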
@@ -18,6 +18,8 @@
 #define HWY_DISABLED_TARGETS HWY_SCALAR
 #endif

+#include "ops/ops.h"
+
 #include <stddef.h>
 #include <stdio.h>
@@ -30,12 +32,10 @@
 #include <vector>

 #include "compression/compress.h"  // BF16
-#include "gemma/activations.h"
 #include "gemma/common.h"
 #include "gemma/configs.h"
-#include "util/allocator.h"
 #include "util/test_util.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/tests/hwy_gtest.h"
@@ -407,8 +407,9 @@ void TestRopeAndMulBy() {
   std::vector<float> qactual(dim_qkv);
   std::vector<float> kexpected(dim_qkv);
   std::vector<float> kactual(dim_qkv);
-  RowVectorBatch<float> inv_timescale = gcpp::Activations::CreateInvTimescale(
-      config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk);
+  RowVectorBatch<float> inv_timescale = gcpp::CreateInvTimescale(
+      config.layer_configs[0].qkv_dim,
+      config.layer_configs[0].post_qk == PostQKType::HalfRope);
   // Assert VectorizedRope computation is same as regular rope at different pos.
   for (int pos = 1; pos < 500; pos++) {
     // Rope'd Q embeddings
@@ -18,7 +18,6 @@
 #include <stdio.h>

 #include <atomic>
-#include <cstdio>
 #include <vector>

 #include "util/basics.h"  // MaybeCheckInitialized
@@ -76,31 +75,42 @@ size_t DetectPageSize() {

 static size_t line_bytes_;
 static size_t vector_bytes_;
+static size_t step_bytes_;
 static size_t quantum_bytes_;
+static size_t quantum_steps_;
 static size_t l1_bytes_;
 static size_t l2_bytes_;
 static size_t l3_bytes_;
 static bool should_bind_ = false;

 size_t Allocator::LineBytes() { return line_bytes_; }
 size_t Allocator::VectorBytes() { return vector_bytes_; }
+size_t Allocator::StepBytes() { return step_bytes_; }
 size_t Allocator::QuantumBytes() { return quantum_bytes_; }
+size_t Allocator::QuantumSteps() { return quantum_steps_; }
 size_t Allocator::L1Bytes() { return l1_bytes_; }
 size_t Allocator::L2Bytes() { return l2_bytes_; }
 size_t Allocator::L3Bytes() { return l3_bytes_; }
 bool Allocator::ShouldBind() { return should_bind_; }

 void Allocator::Init(const BoundedTopology& topology) {
   line_bytes_ = DetectLineBytes();
   vector_bytes_ = hwy::VectorBytes();
-  quantum_bytes_ = HWY_MAX(line_bytes_, vector_bytes_);  // may overwrite below
+  step_bytes_ = HWY_MAX(line_bytes_, vector_bytes_);
+  quantum_bytes_ = step_bytes_;  // may overwrite below

+  const BoundedTopology::Cluster& cluster = topology.GetCluster(0, 0);
   if (const hwy::Cache* caches = hwy::DataCaches()) {
     l1_bytes_ = caches[1].size_kib << 10;
     l2_bytes_ = caches[2].size_kib << 10;
     l3_bytes_ = (caches[3].size_kib << 10) * caches[3].cores_sharing;
   } else {  // Unknown, make reasonable assumptions.
-    const BoundedTopology::Cluster& cluster = topology.GetCluster(0, 0);
     l1_bytes_ = 32 << 10;
     l2_bytes_ = (cluster.PrivateKiB() ? cluster.PrivateKiB() : 256) << 10;
+  }
+  if (l3_bytes_ == 0) {
     l3_bytes_ = (cluster.SharedKiB() ? cluster.SharedKiB() : 1024) << 10;
   }

   // Prerequisites for binding:
   // - supported by the OS (currently Linux only),
|
|||
// Ensure pages meet the alignment requirements of `AllocBytes`.
|
||||
HWY_ASSERT(page_bytes >= quantum_bytes_);
|
||||
quantum_bytes_ = page_bytes;
|
||||
// Ensure MaxQuantumBytes() is an upper bound.
|
||||
HWY_ASSERT(MaxQuantumBytes() >= quantum_bytes_);
|
||||
quantum_bytes_ = HWY_MIN(quantum_bytes_, MaxQuantumBytes());
|
||||
should_bind_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
HWY_DASSERT(quantum_bytes_ % step_bytes_ == 0);
|
||||
quantum_steps_ = quantum_bytes_ / step_bytes_;
|
||||
}
|
||||
|
||||
Allocator::PtrAndDeleter Allocator::AllocBytes(size_t bytes) {
|
||||
|
|
|
|||
|
|
@@ -72,12 +72,17 @@ class Allocator {
   static size_t LineBytes();
   // Bytes per full vector. Used to compute loop steps.
   static size_t VectorBytes();
-  // Granularity of regions processed by different threads. Their start and
-  // length of regions should be divisible by this, which is at least
-  // `HWY_MAX(LineBytes(), VectorBytes())`.
+  // Work granularity that avoids false sharing and partial vectors.
+  static size_t StepBytes();  // = HWY_MAX(LineBytes(), VectorBytes())
+  // Granularity like `StepBytes()`, but when NUMA may be involved.
   static size_t QuantumBytes();
+  // Upper bound on `QuantumBytes()`, for stack allocations.
+  static constexpr size_t MaxQuantumBytes() { return 4096; }
+  static size_t QuantumSteps();  // = QuantumBytes() / StepBytes()

   static size_t L1Bytes();
   static size_t L2Bytes();
   static size_t L3Bytes();

   // Returns pointer aligned to `QuantumBytes()`.
   template <typename T>
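How the new accessors relate, per the definitions in the allocator.cc hunk above (a sketch of the invariants, not new API):

// StepBytes() = HWY_MAX(LineBytes(), VectorBytes()): region starts/lengths
// that are multiples of this avoid false sharing and partial vectors.
// QuantumBytes() equals StepBytes(), or the page size when NUMA binding is
// enabled, capped at MaxQuantumBytes() = 4096. QuantumSteps() ties them:
HWY_DASSERT(Allocator::QuantumBytes() % Allocator::StepBytes() == 0);
HWY_DASSERT(Allocator::QuantumSteps() ==
            Allocator::QuantumBytes() / Allocator::StepBytes());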
@@ -192,10 +197,9 @@ class RowPtr {
   RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride)
       : row0_(row0),
         stride_(stride),
-        step_(static_cast<uint32_t>(
-            HWY_MAX(Allocator::LineBytes(), Allocator::VectorBytes()))),
+        step_(static_cast<uint32_t>(Allocator::StepBytes())),
         cols_(static_cast<uint32_t>(cols)),
-        row_mask_(Allocator::QuantumBytes() / step_ - 1) {
+        row_mask_(Allocator::QuantumSteps() - 1) {
     HWY_DASSERT(stride >= cols);
     HWY_DASSERT(row_mask_ != ~size_t{0});
     row_mask_ = 0;  // TODO: remove
@@ -85,6 +85,9 @@ struct IndexRange {
   IndexRange& operator=(const IndexRange& other) = default;

   size_t Num() const { return end_ - begin_; }
+  bool Contains(IndexRange other) const {
+    return other.begin_ >= begin_ && other.end_ <= end_;
+  }

   // Enable range-based for loops.
   class Iterator {
@@ -89,21 +89,27 @@ class Pinning {

   // If want_pin_, tries to pin each worker in `pool` to an LP in `cluster`,
   // and sets `any_error_` if any fails.
-  void MaybePin(const BoundedTopology::Cluster& cluster, PoolPtr& pool) {
-    if (HWY_UNLIKELY(!want_pin_)) return;
-
+  void MaybePin(size_t pkg_idx, size_t cluster_idx,
+                const BoundedTopology::Cluster& cluster, PoolPtr& pool) {
     const std::vector<size_t> lps = cluster.LPVector();
     HWY_ASSERT(pool->NumWorkers() <= lps.size());
-    pool->Run(
-        0, pool->NumWorkers(),
-        [this, &pool, &lps](uint64_t task, size_t thread) {
+    pool->Run(0, pool->NumWorkers(), [&](uint64_t task, size_t thread) {
       HWY_ASSERT(task == thread);  // each worker has one task

+      char buf[16];  // Linux limitation
+      const int bytes_written = snprintf(buf, sizeof(buf), "P%zu X%02zu C%03zu",
+                                         pkg_idx, cluster_idx, task);
+      HWY_ASSERT(bytes_written < sizeof(buf));
+      hwy::SetThreadName(buf, 0);  // does not support varargs
+
+      if (HWY_LIKELY(want_pin_)) {
         if (HWY_UNLIKELY(!hwy::PinThreadToLogicalProcessor(lps[task]))) {
           fprintf(stderr,
                   "Pinning failed for task %zu of %zu to %zu (size %zu)\n",
                   task, pool->NumWorkers(), lps[task], lps.size());
           (void)any_error_.test_and_set();
         }
+      }
     });
   }
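Worker threads are now named whether or not pinning is requested; the name packs package, cluster, and worker index into Linux's 15-character limit (plus terminator, hence the 16-byte buffer). A standalone illustration of the format:

#include <cassert>
#include <cstddef>
#include <cstdio>

void FormatName(char (&buf)[16], size_t pkg, size_t cluster, size_t task) {
  const int n = std::snprintf(buf, sizeof buf, "P%zu X%02zu C%03zu",
                              pkg, cluster, task);
  assert(n > 0 && n < static_cast<int>(sizeof buf));
  // e.g. pkg=0, cluster=1, task=3 -> "P0 X01 C003" (11 chars, fits).
}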
@@ -466,7 +472,8 @@ NestedPools::Package::Package(const BoundedTopology& topology, size_t pkg_idx,
     clusters_[cluster_idx] =
         MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster));
     // Pin workers AND the calling thread from `all_clusters`.
-    GetPinning().MaybePin(cluster, clusters_[cluster_idx]);
+    GetPinning().MaybePin(pkg_idx, cluster_idx, cluster,
+                          clusters_[cluster_idx]);
   });
 }
@@ -390,6 +390,25 @@ class IndexRangePartition {
                      TaskSize());
   }

+  template <typename Func>
+  void VisitAll(const Func& func) const {
+    for (size_t task_idx = 0; task_idx < NumTasks(); ++task_idx) {
+      func(Range(task_idx));
+    }
+  }
+
+  template <typename Func>
+  void VisitFirst(const Func& func) const {
+    func(Range(0));
+  }
+
+  template <typename Func>
+  void VisitRemaining(const Func& func) const {
+    for (size_t task_idx = 1; task_idx < NumTasks(); ++task_idx) {
+      func(Range(task_idx));
+    }
+  }
+
  private:
   IndexRange range_;
   uint32_t task_size_;
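Usage sketch for the new visitors (context assumed; `Use` is a hypothetical callback, and the StaticPartition parameter meanings are taken from its use elsewhere in this diff): VisitFirst/VisitRemaining let a caller treat task 0 specially, while VisitAll walks every task.

const gcpp::IndexRangePartition parts = gcpp::StaticPartition(
    gcpp::IndexRange(0, 100), /*num=*/4, /*multiple=*/1);
parts.VisitFirst([](const gcpp::IndexRange& r) {
  for (size_t i : r) Use(i);  // e.g. initialize using the first block
});
parts.VisitRemaining([](const gcpp::IndexRange& r) {
  for (size_t i : r) Use(i);  // the remaining blocks
});
// parts.VisitAll(...) visits task 0 and the rest in one pass.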