mirror of https://github.com/google/gemma.cpp.git

De-singleton ThreadingContext so callers can pass in their own.

weights.cc: fix BindB argument for bf16 tensors
threading_test: enable autotune

PiperOrigin-RevId: 785763618

parent 5474146129
commit e76e29ce11
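The change below removes the global ThreadingContext singleton: callers construct a ThreadingContext from ThreadingArgs and pass it (or its pools/allocator members) into Gemma, MatMulEnv, KVCache and the weight-loading path. A minimal sketch of the resulting call-site pattern, lifted from the example main() and gemma/run.cc hunks below; only util/threading_context.h is confirmed by this diff, the other include paths are assumptions:

    #include "gemma/gemma.h"             // Gemma, KVCache (assumed path)
    #include "ops/matmul.h"              // MatMulEnv (assumed path)
    #include "util/threading_context.h"  // ThreadingArgs, ThreadingContext

    void GenerateExample(const gcpp::LoaderArgs& loader,
                         const gcpp::ThreadingArgs& threading,
                         const gcpp::InferenceArgs& inference) {
      // Each caller now owns its own context instead of ThreadingContext::Get().
      gcpp::ThreadingContext ctx(gcpp::UpdateArgs(threading, inference));
      gcpp::MatMulEnv env(ctx);                   // references ctx
      gcpp::Gemma gemma(loader, inference, ctx);  // ctx.pools parallelizes loading
      gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference, ctx.allocator);
      // ... run Generate* with gemma, env and kv_cache ...
    }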
@@ -389,6 +389,7 @@ cc_test(
     deps = [
         ":mat",
         ":ops",
+        ":threading_context",
         "@googletest//:gtest_main",  # buildcleaner: keep
         "//compression:compress",
         "@highway//:hwy",
@@ -57,6 +57,9 @@ class SbsWriterImpl : public ISbsWriter {
   template <typename Packed>
   void InsertT(const char* name, F32Span weights,
                const TensorInfo& tensor_info) {
+    // TODO(janwas): 1D parallel-for.
+    hwy::ThreadPool& pool = ctx_.pools.Pool();
+
     MatPtrT<Packed> mat(name, ExtentsFromInfo(&tensor_info));
     // SFP and NUQ (which uses SFP for cluster centers) have a limited range
     // and depending on the input values may require rescaling. Scaling is

@@ -73,13 +76,13 @@ class SbsWriterImpl : public ISbsWriter {

     mat.AppendTo(serialized_mat_ptrs_);
     mat_owners_.push_back(MatOwner());
-    mat_owners_.back().AllocateFor(mat, MatPadding::kPacked);
+    mat_owners_.back().AllocateFor(mat, ctx_.allocator, MatPadding::kPacked);

     // Handle gemma_export_test's MockArray. Write blobs so that the test
     // succeeds, but we only have 10 floats, not the full tensor.
     if (weights.size() == 10 && mat.Extents().Area() != 10) {
       Compress(weights.data(), weights.size(), working_set_, mat.Span(),
-               /*packed_ofs=*/0, pool_);
+               /*packed_ofs=*/0, pool);
       writer_.Add(name, mat.Packed(), mat.ElementBytes() * 10);
       return;
     }

@@ -89,12 +92,12 @@ class SbsWriterImpl : public ISbsWriter {
                TypeName(TypeEnum<Packed>()));
     HWY_ASSERT(weights.size() == mat.Extents().Area());
     Compress(weights.data(), weights.size(), working_set_, mat.Span(),
-             /*packed_ofs=*/0, pool_);
+             /*packed_ofs=*/0, pool);
     writer_.Add(name, mat.Packed(), mat.PackedBytes());
   }

  public:
-  SbsWriterImpl() : pool_(ThreadingContext::Get().pools.Pool()) {}
+  SbsWriterImpl() : ctx_(ThreadingArgs()) {}

   void Insert(const char* name, F32Span weights, Type type,
               const TensorInfo& tensor_info) override {

@@ -122,18 +125,18 @@ class SbsWriterImpl : public ISbsWriter {
     const GemmaTokenizer tokenizer(
         tokenizer_path.empty() ? kMockTokenizer
                                : ReadFileToString(Path(tokenizer_path)));
-    WriteSingleFile(config, tokenizer, serialized_mat_ptrs_, writer_, pool_,
-                    gcpp::Path(path));
+    WriteSingleFile(config, tokenizer, serialized_mat_ptrs_, writer_,
+                    ctx_.pools.Pool(), gcpp::Path(path));
   }

-  hwy::ThreadPool& pool_;
+  ThreadingContext ctx_;
   std::vector<MatOwner> mat_owners_;
   CompressWorkingSet working_set_;
   BlobWriter writer_;
   std::vector<uint32_t> serialized_mat_ptrs_;
 };

-ISbsWriter* NewSbsWriter() { return new SbsWriterImpl; }
+ISbsWriter* NewSbsWriter() { return new SbsWriterImpl(); }

 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp
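Components that previously grabbed the process-wide pool, like SbsWriterImpl above, now own a ThreadingContext constructed from default ThreadingArgs and fetch the pool from it on demand. A sketch of that ownership pattern with an illustrative class name (only the two include paths appear in this diff):

    #include "hwy/contrib/thread_pool/thread_pool.h"
    #include "util/threading_context.h"

    // Illustrative writer-style class that owns its own context.
    class MyWriter {
     public:
      MyWriter() : ctx_(gcpp::ThreadingArgs()) {}  // default-configured pools

      void DoParallelWork() {
        hwy::ThreadPool& pool = ctx_.pools.Pool();  // per-instance pool
        pool.Run(0, 16, [](uint64_t task, size_t /*thread*/) { (void)task; });
        // Allocations would go through ctx_.allocator.
      }

     private:
      gcpp::ThreadingContext ctx_;  // must outlive anything that captures it
    };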
@@ -69,12 +69,13 @@ void ForeachPackedAndRawType() {

 // Generates inputs: deterministic, within max SfpStream range.
 template <typename MatT>
-MatStorageT<MatT> GenerateMat(const Extents2D& extents, MatPadding padding,
+MatStorageT<MatT> GenerateMat(const Extents2D& extents,
+                              const Allocator& allocator, MatPadding padding,
                               hwy::ThreadPool& pool) {
   gcpp::CompressWorkingSet ws;
   ws.tls.resize(pool.NumWorkers());
-  MatStorageT<float> raw("raw", extents, MatPadding::kPacked);
-  MatStorageT<MatT> compressed("mat", extents, padding);
+  MatStorageT<float> raw("raw", extents, allocator, MatPadding::kPacked);
+  MatStorageT<MatT> compressed("mat", extents, allocator, padding);
   const float scale = SfpStream::kMax / extents.Area();
   pool.Run(0, extents.rows, [&](const size_t r, size_t thread) {
     float* HWY_RESTRICT row = raw.Row(r);

@@ -95,12 +96,13 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents, MatPadding padding,
 // Same, but `extents` describes the transposed matrix.
 template <typename MatT>
 MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
+                                        const Allocator& allocator,
                                         MatPadding padding,
                                         hwy::ThreadPool& pool) {
   gcpp::CompressWorkingSet ws;
   ws.tls.resize(pool.NumWorkers());
-  MatStorageT<float> raw("raw", extents, MatPadding::kPacked);
-  MatStorageT<MatT> compressed("trans", extents, padding);
+  MatStorageT<float> raw("raw", extents, allocator, MatPadding::kPacked);
+  MatStorageT<MatT> compressed("trans", extents, allocator, padding);
   const float scale = SfpStream::kMax / extents.Area();
   pool.Run(0, extents.rows, [&](const size_t r, size_t thread) {
     float* HWY_RESTRICT row = raw.Row(r);
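Both helpers now take the allocator explicitly, between the extents and the padding. A sketch of a call site, mirroring the updated BenchMatMul further down (namespace qualifiers abbreviated):

    // Sketch: generating test matrices with an explicit Allocator.
    ThreadingArgs threading_args;
    ThreadingContext ctx(threading_args);
    hwy::ThreadPool& pool = ctx.pools.Pool();
    const Extents2D A_extents(128, 256);
    MatStorageT<float> a =
        GenerateMat<float>(A_extents, ctx.allocator, MatPadding::kOdd, pool);
    MatStorageT<BF16> b_trans = GenerateTransposedMat<BF16>(
        A_extents, ctx.allocator, MatPadding::kOdd, pool);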
@@ -74,7 +74,8 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
     size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
     std::vector<int> prompt_slice(prompt.begin() + pos,
                                   prompt.begin() + pos + num_tokens);
-    KVCache kv_cache(gemma.GetModelConfig(), gemma.Inference());
+    KVCache kv_cache(gemma.GetModelConfig(), gemma.Inference(),
+                     env.MutableEnv().ctx.allocator);
     float entropy =
         ComputeCrossEntropy(*env.GetGemma(), num_tokens, prompt_slice, kv_cache,
                             env.MutableEnv(), env.Verbosity());
@@ -50,14 +50,13 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {

 GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
                    const InferenceArgs& inference)
-    : env_(MakeMatMulEnv(threading, inference)),
-      gemma_(loader, inference, env_.ctx.pools) {
+    : ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) {
   const ModelConfig& config = gemma_.GetModelConfig();
   // Only allocate one for starters because GenerateBatch might not be called.
-  kv_caches_.push_back(KVCache(config, inference));
+  kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));

   if (inference.verbosity >= 2) {
-    ShowConfig(loader, threading, inference, config);
+    ShowConfig(loader, threading, inference, config, ctx_);
   }

   InitGenerator(inference, gen_);

@@ -141,7 +140,8 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(

   // Ensure we have at least one KVCache per query.
   while (kv_caches_.size() < num_queries) {
-    kv_caches_.push_back(KVCache(gemma_.GetModelConfig(), gemma_.Inference()));
+    kv_caches_.push_back(
+        KVCache(gemma_.GetModelConfig(), gemma_.Inference(), ctx_.allocator));
   }
   const hwy::Span<KVCache> kv_caches(&kv_caches_[0], num_queries);

@@ -228,7 +228,8 @@ static constexpr const char* CompiledConfig() {
 }

 void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
-                const InferenceArgs& inference, const ModelConfig& config) {
+                const InferenceArgs& inference, const ModelConfig& config,
+                const ThreadingContext& ctx) {
   threading.Print(inference.verbosity);
   loader.Print(inference.verbosity);
   inference.Print(inference.verbosity);

@@ -241,7 +242,6 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
   char* dt = ctime(&now);  // NOLINT
   char cpu100[100] = "unknown";
   (void)hwy::platform::GetCpuString(cpu100);
-  const ThreadingContext& ctx = ThreadingContext::Get();

   fprintf(stderr,
           "Date & Time : %s"  // dt includes \n
@@ -49,9 +49,6 @@ class GemmaEnv {
   GemmaEnv(int argc, char** argv);
   GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
            const InferenceArgs& inference);
-  // Avoid memory leaks in test.
-  ~GemmaEnv() { ThreadingContext::ThreadHostileInvalidate(); }
-
   MatMulEnv& Env() { return env_; }

   size_t MaxGeneratedTokens() const {

@@ -115,6 +112,7 @@ class GemmaEnv {
   MatMulEnv& MutableEnv() { return env_; }

  private:
+  ThreadingContext ctx_;
   MatMulEnv env_;
   Gemma gemma_;
   std::mt19937 gen_;  // Random number generator.

@@ -126,7 +124,8 @@ class GemmaEnv {
 void LogSpeedStats(double time_start, size_t total_tokens);

 void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
-                const InferenceArgs& inference, const ModelConfig& config);
+                const InferenceArgs& inference, const ModelConfig& config,
+                const ThreadingContext& ctx);
 void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
               const InferenceArgs& inference);
@@ -51,9 +51,10 @@ int main(int argc, char** argv) {
   }

   // Instantiate model and KV Cache
-  gcpp::MatMulEnv env(MakeMatMulEnv(threading, inference));
-  gcpp::Gemma gemma(loader, inference, env.ctx.pools);
-  gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference);
+  gcpp::ThreadingContext ctx(gcpp::UpdateArgs(threading, inference));
+  gcpp::MatMulEnv env(ctx);
+  gcpp::Gemma gemma(loader, inference, ctx);
+  gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference, ctx.allocator);
   size_t generated = 0;

   // Initialize random number generator
@@ -35,9 +35,10 @@ class SimplifiedGemma {
   SimplifiedGemma(const gcpp::LoaderArgs& loader,
                   const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
                   const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
-      : env_(MakeMatMulEnv(threading, inference)),
-        gemma_(loader, inference, env_.ctx.pools),
-        kv_cache_(gemma_.GetModelConfig(), inference) {
+      : ctx_(UpdateArgs(threading, inference)),
+        env_(ctx_),
+        gemma_(loader, inference, ctx_),
+        kv_cache_(gemma_.GetModelConfig(), inference, ctx_.allocator) {
     // Initialize random number generator
     std::random_device rd;
     gen_.seed(rd());

@@ -88,6 +89,7 @@ class SimplifiedGemma {
   ~SimplifiedGemma() = default;

  private:
+  gcpp::ThreadingContext ctx_;
   gcpp::MatMulEnv env_;
   gcpp::Gemma gemma_;
   gcpp::KVCache kv_cache_;
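Note the member order in both GemmaEnv and SimplifiedGemma: the new ctx_ is declared before env_, gemma_ and kv_cache_, so it is constructed first and destroyed last, which the initializer lists above rely on. A reduced sketch of that dependency (class name illustrative):

    // Declaration order drives construction order, so ctx_ must precede the
    // members that reference it.
    class Owner {
     public:
      Owner(const gcpp::LoaderArgs& loader, const gcpp::ThreadingArgs& threading,
            const gcpp::InferenceArgs& inference)
          : ctx_(gcpp::UpdateArgs(threading, inference)),  // built first
            env_(ctx_),                                    // references ctx_
            gemma_(loader, inference, ctx_),               // loads via ctx_.pools
            kv_cache_(gemma_.GetModelConfig(), inference, ctx_.allocator) {}

     private:
      gcpp::ThreadingContext ctx_;  // destroyed last
      gcpp::MatMulEnv env_;
      gcpp::Gemma gemma_;
      gcpp::KVCache kv_cache_;
    };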
@@ -35,13 +35,15 @@ namespace gcpp {

 struct GriffinActivations {
   GriffinActivations(const ModelConfig& config, size_t batch_size,
-                     MatPadding pad)
-      : griffin_x("griffin_x", Extents2D(batch_size, config.model_dim), pad),
-        griffin_y("griffin_y", Extents2D(batch_size, config.model_dim), pad),
-        griffin_gate_x("griffin_gate_x",
-                       Extents2D(batch_size, config.model_dim), pad),
-        griffin_multiplier("griffin_mul",
-                           Extents2D(batch_size, config.model_dim), pad) {}
+                     const Allocator& allocator)
+      : griffin_x(
+            MatFactory("griffin_x", batch_size, config.model_dim, allocator)),
+        griffin_y(
+            MatFactory("griffin_y", batch_size, config.model_dim, allocator)),
+        griffin_gate_x(MatFactory("griffin_gate_x", batch_size,
+                                  config.model_dim, allocator)),
+        griffin_multiplier(MatFactory("griffin_mul", batch_size,
+                                      config.model_dim, allocator)) {}

   void SetBatchSize(size_t batch_size) {
     if (griffin_x.Rows() == 0) return;

@@ -70,34 +72,34 @@ struct AttentionActivations {

   AttentionActivations(
       const ModelConfig& config, const LayerConfig& layer_config,
-      size_t batch_size, size_t seq_len, MatPadding pad,
+      size_t batch_size, size_t seq_len, const Allocator& allocator,
       std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : config(config),

         // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
         // and does not use an external KV cache.
-        q("q",
-          Extents2D(batch_size,
-                    config.vocab_size == 0
-                        ? layer_config.heads * 3 * layer_config.qkv_dim
-                        : layer_config.heads * layer_config.qkv_dim),
-          pad),
+        q(MatFactory("q", batch_size,
+                     config.vocab_size == 0
+                         ? layer_config.heads * 3 * layer_config.qkv_dim
+                         : layer_config.heads * layer_config.qkv_dim,
+                     allocator)),

-        pre_att_rms_out("pre_att_rms_out",
-                        Extents2D(batch_size, config.model_dim), pad),
-        att("att", Extents2D(batch_size, layer_config.heads * seq_len), pad),
-        att_out(
-            "att_out",
-            Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim),
-            pad),
-        att_sums("att_sums", Extents2D(batch_size, config.model_dim), pad),
+        pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
+                                   config.model_dim, allocator)),
+        att(MatFactory("att", batch_size, layer_config.heads * seq_len,
+                       allocator)),
+        att_out(MatFactory("att_out", batch_size,
+                           layer_config.heads * layer_config.qkv_dim,
+                           allocator)),
+        att_sums(
+            MatFactory("att_sums", batch_size, config.model_dim, allocator)),

         inv_timescale(
-            CreateInvTimescale(layer_config.qkv_dim,
+            CreateInvTimescale(allocator, layer_config.qkv_dim,
                                layer_config.post_qk == PostQKType::HalfRope)),
         inv_timescale_global(CreateInvTimescale(
-            layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope,
-            1000000.0)),
+            allocator, layer_config.qkv_dim,
+            layer_config.post_qk == PostQKType::HalfRope, 1000000.0)),

         div_seq_len(static_cast<uint32_t>(seq_len)),
         div_heads(static_cast<uint32_t>(layer_config.heads)),

@@ -149,21 +151,23 @@ struct AttentionActivations {

 struct Activations {
   Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
+              const Allocator& allocator,
               std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : layer_config(config.layer_configs[0]),

-        x("x", Extents2D(batch_size, config.model_dim), pad_),
-        logits("logits", Extents2D(batch_size, config.vocab_size), pad_),
+        x(MatFactory("x", batch_size, config.model_dim, allocator)),
+        logits(MatFactory("logits", batch_size, config.vocab_size, allocator)),

-        pre_ffw_rms_out("pre_ffw_rms_out",
-                        Extents2D(batch_size, config.model_dim), pad_),
-        C1("C1", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
-        C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
-        ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_),
+        pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size,
+                                   config.model_dim, allocator)),
+        C1(MatFactory("C1", batch_size, layer_config.ff_hidden_dim, allocator)),
+        C2(MatFactory("C2", batch_size, layer_config.ff_hidden_dim, allocator)),
+        ffw_out(MatFactory("ffw_out", batch_size, config.model_dim, allocator)),

-        attention(config, layer_config, batch_size, seq_len, pad_, row_ptrs),
+        attention(config, layer_config, batch_size, seq_len, allocator,
+                  row_ptrs),
         griffin(config, config.model == Model::GRIFFIN_2B ? batch_size : 0,
-                pad_) {
+                allocator) {
     HWY_ASSERT(batch_size != 0);

     // For MatMul outputs, precompute their row pointers.

@@ -193,15 +197,14 @@ struct Activations {
   }

   const LayerConfig& layer_config;
-  const Extents2D none_ = Extents2D();
-  const MatPadding pad_ = MatPadding::kOdd;

   MatStorageT<float> x;  // input
   MatStorageT<float> logits;

   // Gated FFW
   MatStorageT<BF16> pre_ffw_rms_out;
-  MatStorageT<float> C1;  // TODO: BF16 after Activation() supports it
+  // Norm may be large, so prefer to keep as f32.
+  MatStorageT<float> C1;
   MatStorageT<float> C2;
   MatStorageT<BF16> ffw_out;
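The activation members above switch from a (name, Extents2D, MatPadding) constructor to a MatFactory helper that bundles the name, batch size, column count and allocator. A sketch of building one standalone activation buffer under that assumption; the exact MatFactory return type is not shown in this diff, only that MatStorageT members can be initialized from it:

    // Sketch only: one activation matrix allocated via the caller's context.
    gcpp::ThreadingArgs threading_args;
    gcpp::ThreadingContext ctx(threading_args);
    gcpp::MatStorageT<float> x(
        gcpp::MatFactory("x", /*batch_size=*/4, /*cols=*/2048, ctx.allocator));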
@@ -43,8 +43,10 @@ namespace gcpp {

 // ConversationData constructor implementation
 ConversationData::ConversationData(const ModelConfig& model_config,
-                                   const InferenceArgs& inference_args)
-    : kv_cache(std::make_unique<KVCache>(model_config, inference_args)),
+                                   const InferenceArgs& inference_args,
+                                   const Allocator& allocator)
+    : kv_cache(
+          std::make_unique<KVCache>(model_config, inference_args, allocator)),
       abs_pos(0) {}

 // ConversationData copy constructor implementation

@@ -101,15 +103,16 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
                            int max_generated_tokens)
     : inference_args(inference_args),
       threading_args(threading_args),
-      matmul_env(MakeMatMulEnv(threading_args, inference_args)),
+      ctx(UpdateArgs(threading_args, inference_args)),
+      matmul_env(ctx),
       active_conversation_name("default"),
-      model(loader, inference_args, matmul_env.ctx.pools) {
+      model(loader, inference_args, matmul_env.ctx) {
   std::stringstream ss;

   LogDebug("Creating initial ConversationData");
   // Create the initial ConversationData object using make_shared
   active_conversation = std::make_shared<ConversationData>(
-      model.GetModelConfig(), inference_args);
+      model.GetModelConfig(), inference_args, ctx.allocator);

   LogDebug(
       "Storing initial ConversationData in conversation_cache[\"default\"]");

@@ -188,7 +191,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
           ? Extents2D(model_config.vit_config.seq_len / (pool_dim * pool_dim),
                       model_config.model_dim)
           : Extents2D(0, 0),
-      MatPadding::kOdd);
+      ctx.allocator, MatPadding::kOdd);
   if (image_data != nullptr) {
     HWY_ASSERT(model_config.wrapping == PromptWrapping::PALIGEMMA ||
                model_config.wrapping == PromptWrapping::GEMMA_VLM);

@@ -41,7 +41,8 @@ namespace gcpp {
 // Struct to hold data for a single conversation thread
 struct ConversationData {
   ConversationData(const ModelConfig& model_config,
-                   const InferenceArgs& inference_args);
+                   const InferenceArgs& inference_args,
+                   const Allocator& allocator);
   ConversationData(const ConversationData& other);

   std::unique_ptr<KVCache> kv_cache;

@@ -178,8 +179,8 @@ class GemmaContext {
       // rewind to initial state.
       active_conversation->abs_pos = 0;
       // Replace the cache within the current ConversationData object
-      active_conversation->kv_cache =
-          std::make_unique<KVCache>(model.GetModelConfig(), inference_args);
+      active_conversation->kv_cache = std::make_unique<KVCache>(
+          model.GetModelConfig(), inference_args, ctx.allocator);

       LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
     } else {

@@ -197,7 +198,7 @@ class GemmaContext {
     LogDebug("Creating new conversation");
     // Create a new ConversationData object using make_shared
     conversation_cache[name] = std::make_shared<ConversationData>(
-        model.GetModelConfig(), inference_args);
+        model.GetModelConfig(), inference_args, ctx.allocator);
     return true;
   }

@@ -280,6 +281,7 @@ class GemmaContext {
   // Cached args (remain global for the context)
   InferenceArgs inference_args;
   ThreadingArgs threading_args;
+  ThreadingContext ctx;
   MatMulEnv matmul_env;

   std::string active_conversation_name;
@@ -537,7 +537,7 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
                      const WeightsPtrs& weights, KVCache& kv_cache,
                      MatMulEnv& env, TimingInfo& timing_info) {
   Activations activations(config, runtime_config.prefill_tbatch_size,
-                          kv_cache.SeqLen(), env.row_ptrs);
+                          kv_cache.SeqLen(), env.ctx.allocator, env.row_ptrs);

   AllQueries all_queries(prompt, pos, prefix_end,
                          hwy::Span<KVCache>(&kv_cache, 1));

@@ -555,7 +555,8 @@ void GenerateBatchT(const ModelConfig& config,
   const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
                                         runtime_config.prefill_tbatch_size);
   Activations activations(config, max_batch_size,
-                          all_queries[0].kv_cache.SeqLen(), env.row_ptrs);
+                          all_queries[0].kv_cache.SeqLen(), env.ctx.allocator,
+                          env.row_ptrs);

   for (size_t start = 0; start < all_queries.NumQueries();
        start += runtime_config.decode_qbatch_size) {

@@ -579,7 +580,7 @@ void GenerateImageTokensT(const ModelConfig& config,
     prefill_runtime_config.prefill_tbatch_size =
         num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
     Activations prefill_activations(vit_config, num_tokens, num_tokens,
-                                    env.row_ptrs);
+                                    env.ctx.allocator, env.row_ptrs);
     // Weights are for the full PaliGemma model, not just the ViT part.
     PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
                prefill_activations, env);

@@ -596,28 +597,14 @@ HWY_EXPORT(GenerateSingleT);
 HWY_EXPORT(GenerateBatchT);
 HWY_EXPORT(GenerateImageTokensT);

-MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args,
-                        const InferenceArgs& inference_args) {
-  if (inference_args.decode_qbatch_size >= 256) {
-    ThreadingArgs copy = threading_args;
-    copy.max_packages = 1;
-    ThreadingContext::SetArgs(copy);
-  } else {
-    ThreadingContext::SetArgs(threading_args);
-  }
-
-  return MatMulEnv(ThreadingContext::Get());
-}
-
 Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
-             NestedPools& pools)
+             ThreadingContext& ctx)
     : reader_(loader.weights),
       model_(reader_, loader.tokenizer, loader.wrapping),
       weights_(model_.Config()),
       chat_template_(model_.Tokenizer(), model_.Config().model),
       inference_(inference) {
-  weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_,
-                         pools.Pool());
+  weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_, ctx);
   reader_.CloseFile();
 }
@@ -225,18 +225,15 @@ struct TimingInfo {
   size_t tokens_generated = 0;
 };

-// Returns the `MatMulEnv` after calling `SetArgs`.
-MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args,
-                        const InferenceArgs& inference_args);
-
 // After construction, all methods are const and thread-compatible if using
-// separate MatMulEnv for each thread.
+// separate ThreadingContext for each thread.
 class Gemma {
  public:
   // Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
-  // `pools` are used to parallelize loading.
+  // `ctx` is only used to read tensors, but it is typically also referenced
+  // by the `MatMulEnv` passed to the Generate* methods.
   Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
-        NestedPools& pools);
+        ThreadingContext& ctx);
   ~Gemma();

   // TODO: rename to Config()

@@ -256,6 +256,16 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   }
 };

+static inline ThreadingArgs UpdateArgs(const ThreadingArgs& threading_args,
+                                       const InferenceArgs& inference_args) {
+  if (inference_args.decode_qbatch_size >= 256) {
+    ThreadingArgs copy = threading_args;
+    copy.max_packages = 1;
+    return copy;
+  }
+  return threading_args;
+}
+
 }  // namespace gcpp

 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
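UpdateArgs replaces the side-effecting MakeMatMulEnv: it only rewrites the ThreadingArgs (capping max_packages when the decode batch is large) and leaves constructing the context to the caller. Because it returns plain args instead of mutating a global singleton, a caller can also apply further local overrides before building its context, e.g. (threading and inference assumed in scope):

    // Sketch: tweak the returned args locally, then build a private context.
    gcpp::ThreadingArgs args = gcpp::UpdateArgs(threading, inference);
    args.max_packages = 1;  // caller-local override, like UpdateArgs itself does
    gcpp::ThreadingContext ctx(args);
    gcpp::MatMulEnv env(ctx);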
@@ -57,20 +57,24 @@ static size_t CappedSeqLen(const ModelConfig& config,
 }

 KVCache::KVCache(const Extents2D& conv1d_extents,
-                 const Extents2D& rglru_extents, const Extents2D& kv_extents)
-    : conv1d_cache("conv1d_cache", conv1d_extents, MatPadding::kOdd),
-      rglru_cache("rglru_cache", rglru_extents, MatPadding::kOdd),
-      kv_cache("kv", kv_extents, MatPadding::kOdd) {}
+                 const Extents2D& rglru_extents, const Extents2D& kv_extents,
+                 const Allocator& allocator)
+    : conv1d_cache("conv1d_cache", conv1d_extents, allocator, MatPadding::kOdd),
+      rglru_cache("rglru_cache", rglru_extents, allocator, MatPadding::kOdd),
+      kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
+      allocator_(allocator) {}

-KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args)
-    : KVCache(Extents2D(GriffinLayers(config), GriffinConv1dCols(config)),
-              Extents2D(GriffinLayers(config), config.model_dim),
-              Extents2D(CappedSeqLen(config, inference_args),
-                        config.KVCacheCols())) {}
+KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
+                 const Allocator& allocator)
+    : KVCache(
+          Extents2D(GriffinLayers(config), GriffinConv1dCols(config)),
+          Extents2D(GriffinLayers(config), config.model_dim),
+          Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()),
+          allocator) {}

 KVCache KVCache::Copy() {
   KVCache copy(conv1d_cache.Extents(), rglru_cache.Extents(),
-               kv_cache.Extents());
+               kv_cache.Extents(), allocator_);

   if (conv1d_cache.Rows() != 0) {
     CopyMat(conv1d_cache, copy.conv1d_cache);

@@ -28,7 +28,8 @@ namespace gcpp {
 using KV_t = float;

 struct KVCache {
-  KVCache(const ModelConfig& config, const InferenceArgs& inference_args);
+  KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
+          const Allocator& allocator);

   // Returns a deep copy of the KVCache. Use explicit function instead of
   // copy ctor to make the cost explicit.

@@ -47,9 +48,11 @@ struct KVCache {
   MatStorageT<KV_t> kv_cache;  // [seq_len, layers * kv_heads * qkv_dim * 2]

  private:
+  const Allocator& allocator_;
+
   // For use by other ctor and Copy()
   KVCache(const Extents2D& conv1d_extents, const Extents2D& rglru_extents,
-          const Extents2D& kv_extents);
+          const Extents2D& kv_extents, const Allocator& allocator);
 };

 }  // namespace gcpp
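KVCache now records the allocator it was constructed with so that Copy() can reuse it. A sketch of typical usage, given a loaded gemma and InferenceArgs inference in scope:

    // Per-query caches come from the caller's context; Copy() is a deep copy
    // that reuses the stored allocator reference.
    gcpp::ThreadingArgs threading_args;
    gcpp::ThreadingContext ctx(threading_args);
    gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference, ctx.allocator);
    gcpp::KVCache snapshot = kv_cache.Copy();  // e.g. to roll back a dialog turn

Since allocator_ is stored as a reference, the ThreadingContext must outlive every cache created from it.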
gemma/run.cc

@@ -110,7 +110,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
       have_image ? Extents2D(config.vit_config.seq_len / (pool_dim * pool_dim),
                              config.model_dim)
                  : Extents2D(0, 0),
-      MatPadding::kOdd);
+      env.ctx.allocator, MatPadding::kOdd);
   image_tokens.AllocateAndAttachRowPtrs(env.row_ptrs);
   if (have_image) {
     HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA ||

@@ -254,10 +254,11 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
          const InferenceArgs& inference) {
   PROFILER_ZONE("Run.misc");

-  MatMulEnv env(MakeMatMulEnv(threading, inference));
+  ThreadingContext ctx(UpdateArgs(threading, inference));
+  MatMulEnv env(ctx);
   if (inference.verbosity >= 2) env.print_best = true;
-  const Gemma gemma(loader, inference, env.ctx.pools);
-  KVCache kv_cache(gemma.GetModelConfig(), inference);
+  const Gemma gemma(loader, inference, ctx);
+  KVCache kv_cache(gemma.GetModelConfig(), inference, ctx.allocator);

   if (inference.verbosity >= 1) {
     std::string instructions =

@@ -284,7 +285,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
     if (inference.IsInteractive()) {
       std::cout << "\033[2J\033[1;1H"  // clear screen
                 << kAsciiArtBanner << "\n\n";
-      ShowConfig(loader, threading, inference, gemma.GetModelConfig());
+      ShowConfig(loader, threading, inference, gemma.GetModelConfig(), ctx);
       std::cout << "\n" << instructions << "\n";
     }
   }
gemma/vit.cc

@@ -80,11 +80,11 @@ class VitAttention {

     // Shift Q, K, VT to MatStorageT.
     MatStorageT<float> Q("Q2", Extents2D(num_tokens_, qkv_dim),
-                         MatPadding::kPacked);
-    MatStorageT<float> K("K2", Extents2D(seq_len, qkv_dim),
+                         env_.ctx.allocator, MatPadding::kPacked);
+    MatStorageT<float> K("K2", Extents2D(seq_len, qkv_dim), env_.ctx.allocator,
                          MatPadding::kPacked);
     MatStorageT<float> C("C2", Extents2D(num_tokens_, seq_len),
-                         MatPadding::kPacked);
+                         env_.ctx.allocator, MatPadding::kPacked);

     // Initialize att_out to zero prior to head loop.
     ZeroInit(activations_.attention.att_out);

@@ -294,7 +294,7 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
   // image_patches is (256, 14 * 14 * 3)
   // Must be padded, see `DoDecompressA`.
   MatStorageT<float> image_patches("patches", Extents2D(num_tokens, patch_size),
-                                   MatPadding::kOdd);
+                                   env.ctx.allocator, MatPadding::kOdd);
   for (size_t i = 0; i < num_tokens; ++i) {
     image.GetPatch(i, image_patches.Row(i));
   }

@@ -329,7 +329,7 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights,
                   weights.vit_encoder_norm_bias, activations.x);

   if (model_config.wrapping == PromptWrapping::GEMMA_VLM) {
-    activations.x = AvgPool4x4(activations.x);
+    activations.x = AvgPool4x4(activations.x, env.ctx.allocator);

     // Apply soft embedding norm before input projection.
     CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) {
@@ -44,7 +44,8 @@
 namespace gcpp {

 // Copies att_weights from `attn_vec_einsum_w`.
-void LayerWeightsPtrs::InitAttWeights(std::vector<MatOwner>& mat_owners) {
+void LayerWeightsPtrs::InitAttWeights(std::vector<MatOwner>& mat_owners,
+                                      const Allocator& allocator) {
   // We only use this tensor for Gemma layers.
   if (layer_config.type != LayerAttentionType::kGemma) return;

@@ -71,7 +72,7 @@ void LayerWeightsPtrs::InitAttWeights(std::vector<MatOwner>& mat_owners) {
     static std::mutex m;
     std::lock_guard<std::mutex> lock(m);
     mat_owners.push_back(MatOwner());
-    mat_owners.back().AllocateFor(att_weights, MatPadding::kOdd);
+    mat_owners.back().AllocateFor(att_weights, allocator, MatPadding::kOdd);
   }

   const size_t T_bytes = att_weights.ElementBytes();

@@ -149,9 +150,10 @@ void LayerWeightsPtrs::SplitAttW1() {
 // Must be called after reading weights via `ForEachTensor`.
 // TODO: exporters should bake this into the weights already.
 // WARNING: called from multiple threads; `mat_owners` requires a lock.
-void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners) {
+void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
+                             const Allocator& allocator) {
   // TODO(janwas): handle NUQ
-  InitAttWeights(mat_owners);
+  InitAttWeights(mat_owners, allocator);
   SplitW1();
   SplitAttW1();
 }

@@ -223,13 +225,15 @@ void WeightsPtrs::CopyFrom(const WeightsPtrs& other) {
 // For reshaping file tensors to the shape expected by the code. This would
 // ideally already happen in the importer. Called by WeightsOwner::Fixup.
 void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
-                        hwy::ThreadPool& pool) {
+                        ThreadingContext& ctx) {
+  // TODO: use 1D parallel-for helper function
+  hwy::ThreadPool& pool = ctx.pools.Pool();
   pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
-    GetLayer(layer)->Fixup(mat_owners);
+    GetLayer(layer)->Fixup(mat_owners, ctx.allocator);
   });

   pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
-    VitLayer(layer)->Fixup(mat_owners);
+    VitLayer(layer)->Fixup(mat_owners, ctx.allocator);
   });
 }

@@ -260,12 +264,12 @@ enum class Mode {

 // Decides whether to read or map based on heuristics and user override.
 static Mode ChooseMode(uint64_t file_bytes, const LoaderArgs& loader,
-                       const InferenceArgs& inference) {
+                       const InferenceArgs& inference,
+                       const Allocator& allocator) {
   Tristate to_bf16 = loader.to_bf16;
   Tristate map = loader.map;

   // Disable mapping if not padded to the base page size.
-  const Allocator& allocator = ThreadingContext::Get().allocator;
   if (file_bytes % allocator.BasePageBytes() != 0) {
     if (map == Tristate::kTrue) {  // Only complain if explicitly requested.
       HWY_WARN("Unable to map non-padded file (%zu, %zu), reading instead.",

@@ -321,30 +325,31 @@ struct TensorToRead {
 // Allocates multiple in parallel and binds to NUMA nodes.
 static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
                                const Mode mode, std::vector<MatOwner>& owners,
-                               hwy::ThreadPool& pool) {
+                               ThreadingContext& ctx) {
   const size_t start = owners.size();
   owners.resize(start + tensors.size());

-  MMParallel parallel(ThreadingContext::Get());
+  MMParallel parallel(ctx);

   // Allocate in parallel because faulting in large tensors is slow.
-  pool.Run(0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
-    TensorToRead& tensor = tensors[task];
-    MatPtr& mat = *tensor.mat;
+  ctx.pools.Pool().Run(
+      0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
+        TensorToRead& tensor = tensors[task];
+        MatPtr& mat = *tensor.mat;

         tensor.prev_type = mat.GetType();
         // We only care about MatMul inputs; skip F32 or small tensors.
         if (tensor.prev_type == Type::kF32 || mat.Rows() < 1024) {
           tensor.keep_type = true;
           tensor.padding = MatPadding::kPacked;  // single I/O for simplicity
         } else if (mode == Mode::kReadBF16) {
           mat.SetType(Type::kBF16);
         }

-        owners[start + task].AllocateFor(*tensor.mat, tensor.padding);
-        // TODO(janwas): MatMul outputs will later also be BF16.
-        BindB(*tensor.mat, sizeof(float), parallel);
+        owners[start + task].AllocateFor(*tensor.mat, ctx.allocator,
+                                         tensor.padding);
+        BindB(*tensor.mat, tensor.mat->ElementBytes(), parallel);
       });
 }

 // Mode == kMap

@@ -482,7 +487,7 @@ static void ReadBatches(const BlobReader& reader,
 // Aborts on error.
 static void MapOrReadAll(std::vector<TensorToRead>& tensors, BlobReader& reader,
                          Mode mode, std::vector<MatOwner>& mat_owners,
-                         hwy::ThreadPool& pool) {
+                         ThreadingContext& ctx) {
   if (mode == Mode::kMap) {
     MapPtr mapped = reader.file().Map();
     if (mapped) return MapAll(tensors, mapped);

@@ -496,9 +501,11 @@ static void MapOrReadAll(std::vector<TensorToRead>& tensors, BlobReader& reader,
   {
     PROFILER_ZONE("Startup.Weights.Allocate");
     // NOTE: this changes the stride of `mats`!
-    AllocateAndBindAll(tensors, mode, mat_owners, pool);
+    AllocateAndBindAll(tensors, mode, mat_owners, ctx);
   }

+  hwy::ThreadPool& pool = ctx.pools.Pool();
+
   if (mode == Mode::kReadBF16) return ReadAllToBF16(tensors, reader, pool);

   const std::vector<IOBatch> batches =

@@ -510,7 +517,7 @@ void WeightsPtrs::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
                                 const LoaderArgs& loader,
                                 const InferenceArgs& inference,
                                 std::vector<MatOwner>& mat_owners,
-                                hwy::ThreadPool& pool) {
+                                ThreadingContext& ctx) {
   // List of tensors to read/map, and where from.
   std::vector<TensorToRead> tensors;

@@ -529,13 +536,14 @@ void WeightsPtrs::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
     HWY_ABORT("Tensor %s is required but not found in file.", t.mat.Name());
   });

-  const Mode mode = ChooseMode(reader.file_bytes(), loader, inference);
+  const Mode mode =
+      ChooseMode(reader.file_bytes(), loader, inference, ctx.allocator);

-  MapOrReadAll(tensors, reader, mode, mat_owners, pool);
+  MapOrReadAll(tensors, reader, mode, mat_owners, ctx);

   {
     PROFILER_ZONE("Startup.Fixup");
-    Fixup(mat_owners, pool);
+    Fixup(mat_owners, ctx);
   }
 }
@@ -29,7 +29,7 @@
 #include "gemma/tensor_info.h"  // TensorInfoRegistry
 #include "io/blob_store.h"      // BlobWriter
 #include "util/mat.h"           // MatPtr
-#include "hwy/contrib/thread_pool/thread_pool.h"
+#include "util/threading_context.h"

 namespace gcpp {

@@ -299,11 +299,12 @@ struct LayerWeightsPtrs {
   // Must be called after reading weights via `ForEachTensor`.
   // TODO: exporters should bake this into the weights already.
   // WARNING: called from multiple threads; `mat_owners` requires a lock.
-  void Fixup(std::vector<MatOwner>& mat_owners);
+  void Fixup(std::vector<MatOwner>& mat_owners, const Allocator& allocator);

  private:
   // Copies att_weights from `attn_vec_einsum_w`.
-  void InitAttWeights(std::vector<MatOwner>& mat_owners);
+  void InitAttWeights(std::vector<MatOwner>& mat_owners,
+                      const Allocator& allocator);

   // For FFN. Fast, only updates pointers.
   void SplitW1();

@@ -426,7 +427,7 @@ struct WeightsPtrs {
   // override for whether to map blobs or read them.
   void ReadFromBlobs(const ModelStore& model, BlobReader& reader,
                      const LoaderArgs& loader, const InferenceArgs& inference,
-                     std::vector<MatOwner>& mat_owners, hwy::ThreadPool& pool);
+                     std::vector<MatOwner>& mat_owners, ThreadingContext& ctx);

   // Adds one blob for each tensor's data and returns all serialized MatPtr.
   std::vector<uint32_t> AddTensorDataToWriter(BlobWriter& writer) const;

@@ -434,7 +435,7 @@ struct WeightsPtrs {
  private:
   // For reshaping file tensors to the shape expected by the code. This would
   // ideally already happen in the importer. Called by ReadFromBlobs.
-  void Fixup(std::vector<MatOwner>& mat_owners, hwy::ThreadPool& pool);
+  void Fixup(std::vector<MatOwner>& mat_owners, ThreadingContext& ctx);
 };  // `WeightsPtrs`
 #undef TENSOR_ARGS
@@ -227,11 +227,12 @@ void ReadAndCompareBlobs(const Path& path1, const Path& path2) {
   BlobVec blobs1 = ReserveMemory(ranges1, all_blobs, pos);
   BlobVec blobs2 = ReserveMemory(ranges2, all_blobs, pos);

-  NestedPools& pools = ThreadingContext::Get().pools;
+  ThreadingArgs args;
+  ThreadingContext ctx(args);
   ReadBothBlobs(reader1, reader2, ranges1, ranges2, total_bytes, blobs1, blobs2,
-                pools);
+                ctx.pools);

-  CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, pools);
+  CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, ctx.pools);
 }

 }  // namespace gcpp
@@ -36,7 +36,9 @@ class BlobStoreTest : public testing::Test {};
 #endif

 TEST(BlobStoreTest, TestReadWrite) {
-  hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool();
+  ThreadingArgs threading_args;
+  ThreadingContext ctx(threading_args);
+  hwy::ThreadPool& pool = ctx.pools.Pool();

   static const std::array<float, 4> kOriginalData = {-1, 0, 3.14159, 2.71828};

@@ -92,7 +94,9 @@ TEST(BlobStoreTest, TestReadWrite) {

 // Ensures padding works for any number of random-sized blobs.
 TEST(BlobStoreTest, TestNumBlobs) {
-  hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool();
+  ThreadingArgs threading_args;
+  ThreadingContext ctx(threading_args);
+  hwy::ThreadPool& pool = ctx.pools.Pool();
   hwy::RandomState rng;

   for (size_t num_blobs = 1; num_blobs <= 512; ++num_blobs) {
@@ -84,19 +84,22 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
   const Extents2D B_extents(N, K);  // already transposed
   const Extents2D C_extents(M, N);

-  MatStorageT<TC> C_slow("c_slow_batch", C_extents, MatPadding::kOdd);
-  MatStorageT<TC> C("c_batch", C_extents, MatPadding::kOdd);
+  MatStorageT<TC> C_slow("c_slow_batch", C_extents, env.ctx.allocator,
+                         MatPadding::kOdd);
+  MatStorageT<TC> C("c_batch", C_extents, env.ctx.allocator, MatPadding::kOdd);

-  MatStorageT<float> add_storage("add", Extents2D(), MatPadding::kPacked);
+  MatStorageT<float> add_storage("add", Extents2D(), env.ctx.allocator,
+                                 MatPadding::kPacked);
   if (add) {
-    add_storage =
-        GenerateMat<float>(Extents2D(1, N), MatPadding::kPacked, pool);
+    add_storage = GenerateMat<float>(Extents2D(1, N), env.ctx.allocator,
+                                     MatPadding::kPacked, pool);
     add_storage.SetScale(1.0f);
   }

-  MatStorageT<TA> a = GenerateMat<TA>(A_extents, MatPadding::kOdd, pool);
-  MatStorageT<TB> b_trans =
-      GenerateTransposedMat<TB>(B_extents, MatPadding::kOdd, pool);
+  MatStorageT<TA> a =
+      GenerateMat<TA>(A_extents, env.ctx.allocator, MatPadding::kOdd, pool);
+  MatStorageT<TB> b_trans = GenerateTransposedMat<TB>(
+      B_extents, env.ctx.allocator, MatPadding::kOdd, pool);

   const float* add_row = add ? add_storage.PackedScale1() : nullptr;
@@ -151,10 +154,10 @@ void BenchAllMatMul() {
     return;
   }

-  ThreadingContext& ctx = ThreadingContext::Get();
+  ThreadingArgs threading_args;
+  ThreadingContext ctx(threading_args);
   fprintf(stderr, "BenchAllMatMul %s %s\n", ctx.topology.TopologyString(),
           ctx.pools.PinString());

   MatMulEnv env(ctx);

   for (size_t batch_size : {1, 4, 128, 512}) {
@@ -999,6 +999,8 @@ struct TestShortDotsT {
     const size_t N = hn::Lanes(d);
    const hn::ScalableTag<float> df;  // for CallDot

+    ThreadingArgs threading_args;
+    ThreadingContext ctx(threading_args);
     CompressWorkingSet work;
     std::mt19937 rng;
     rng.seed(12345);
@@ -1009,14 +1011,14 @@ struct TestShortDotsT {
     // GenerateWellConditionedInputs calls DecompressAndZeroPad to `raw*`,
     // hence they require padding to one vector.
     const size_t padded_num = hwy::RoundUpTo(num, N);
-    MatStorageT<float> raw_w("raw_w", padded_num);
-    MatStorageT<float> raw_v("raw_v", padded_num);
-    MatStorageT<Packed> weights("weights", padded_num);
+    MatStorageT<float> raw_w("raw_w", padded_num, ctx.allocator);
+    MatStorageT<float> raw_v("raw_v", padded_num, ctx.allocator);
+    MatStorageT<Packed> weights("weights", padded_num, ctx.allocator);
     const PackedSpan<Packed> w = weights.Span();
-    MatStorageT<T> vectors("vectors", padded_num);
+    MatStorageT<T> vectors("vectors", padded_num, ctx.allocator);
     const PackedSpan<T> v = vectors.Span();

-    MatStorageT<double> bufs("bufs", num);
+    MatStorageT<double> bufs("bufs", padded_num, ctx.allocator);
     double* HWY_RESTRICT buf = bufs.Row(0);

     for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
@@ -1097,14 +1099,12 @@ void TestAllDot() {

   constexpr size_t kMaxWorkers = 15;

-  // Reset with cap on workers because we only support `kMaxWorkers`.
-  ThreadingContext::ThreadHostileInvalidate();
+  // Limit workers because we only support `kMaxWorkers`.
   ThreadingArgs threading_args;
   threading_args.max_packages = 1;
   threading_args.max_clusters = 1;
   threading_args.max_lps = kMaxWorkers - 1;
-  ThreadingContext::SetArgs(threading_args);
-  ThreadingContext& ctx = ThreadingContext::Get();
+  ThreadingContext ctx(threading_args);

   {  // ensure no profiler zones are active
     const hn::ScalableTag<float> df;
@@ -1116,9 +1116,11 @@ void TestAllDot() {

   constexpr size_t kReps = hn::AdjustedReps(40);
   const size_t num = 24 * 1024;
-  MatStorageT<float> a("a", Extents2D(kMaxWorkers, num), MatPadding::kOdd);
-  MatStorageT<float> b("b", Extents2D(kMaxWorkers, num), MatPadding::kOdd);
-  MatStorageT<double> bufs("bufs", Extents2D(kMaxWorkers, num),
-                           MatPadding::kOdd);
+  MatStorageT<float> a("a", Extents2D(kMaxWorkers, num), ctx.allocator,
+                       MatPadding::kOdd);
+  MatStorageT<float> b("b", Extents2D(kMaxWorkers, num), ctx.allocator,
+                       MatPadding::kOdd);
+  MatStorageT<double> bufs("bufs", Extents2D(kMaxWorkers, num), ctx.allocator,
+                           MatPadding::kOdd);
   std::array<DotStats, kMaxWorkers> all_stats;
@@ -26,6 +26,7 @@
 #include <memory>

 #include "util/mat.h"
+#include "util/threading_context.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -68,10 +69,11 @@ FloatPtr SimpleMatVecAdd(const MatStorageT<float>& mat, const FloatPtr& vec,

 template <typename MatT, size_t kOuter, size_t kInner>
 std::unique_ptr<MatStorageT<float>> GenerateMat(size_t offset,
+                                                const Allocator& allocator,
                                                 hwy::ThreadPool& pool) {
   gcpp::CompressWorkingSet ws;
   const Extents2D extents(kOuter, kInner);
-  auto mat = std::make_unique<MatStorageT<float>>("TestMat", extents,
+  auto mat = std::make_unique<MatStorageT<float>>("TestMat", extents, allocator,
                                                   MatPadding::kPacked);
   FloatPtr raw_mat = hwy::AllocateAligned<float>(extents.Area());
   HWY_ASSERT(raw_mat);
@@ -109,10 +111,12 @@ void AssertClose(const FloatPtr& a, const FloatPtr& b) {
 }

 void TestMatVecAdd() {
-  hwy::ThreadPool pool(hwy::ThreadPool::MaxThreads());
+  ThreadingArgs threading_args;
+  ThreadingContext ctx(threading_args);
+  hwy::ThreadPool& pool = ctx.pools.Pool();
   constexpr size_t kOuter = 128 * 3;
   constexpr size_t kInner = 128 * 5;
-  auto mat = GenerateMat<float, kOuter, kInner>(0, pool);
+  auto mat = GenerateMat<float, kOuter, kInner>(0, ctx.allocator, pool);
   FloatPtr vec = GenerateVec<kInner>(0);
   FloatPtr add = GenerateVec<kOuter>(0);
   FloatPtr expected_out = SimpleMatVecAdd(*mat, vec, add);
@@ -124,11 +128,13 @@ void TestMatVecAdd() {
 }

 void TestTwoMatVecAdd() {
-  hwy::ThreadPool pool(hwy::ThreadPool::MaxThreads());
+  ThreadingArgs threading_args;
+  ThreadingContext ctx(threading_args);
+  hwy::ThreadPool& pool = ctx.pools.Pool();
   constexpr size_t kOuter = 128 * 3;
   constexpr size_t kInner = 128 * 5;
-  auto mat0 = GenerateMat<float, kOuter, kInner>(0, pool);
-  auto mat1 = GenerateMat<float, kOuter, kInner>(1, pool);
+  auto mat0 = GenerateMat<float, kOuter, kInner>(0, ctx.allocator, pool);
+  auto mat1 = GenerateMat<float, kOuter, kInner>(1, ctx.allocator, pool);
   FloatPtr vec = GenerateVec<kInner>(0);
   FloatPtr add0 = GenerateVec<kOuter>(0);
   FloatPtr add1 = GenerateVec<kOuter>(1);
@@ -145,10 +151,13 @@ void TestTwoMatVecAdd() {
 }

 void TestTwoOfsMatVecAddLoop() {
-  hwy::ThreadPool pool(hwy::ThreadPool::MaxThreads());
+  ThreadingArgs threading_args;
+  ThreadingContext ctx(threading_args);
+  hwy::ThreadPool& pool = ctx.pools.Pool();

   constexpr size_t kOuter = 128 * 3;
   constexpr size_t kInner = 128 * 5;
-  auto mat = GenerateMat<float, kOuter, kInner>(0, pool);
+  auto mat = GenerateMat<float, kOuter, kInner>(0, ctx.allocator, pool);
   FloatPtr vec = GenerateVec<kInner>(0);
   FloatPtr add0 = GenerateVec<kOuter>(0);
   FloatPtr add1 = GenerateVec<kOuter>(1);
@@ -53,8 +53,8 @@ constexpr size_t kNR = 4;
 // or less on ISAs with fewer registers, or for the last few rows of A.
 static constexpr size_t kMaxMR = 4;

-// Mostly stateless, can be constructed on the fly by weights.cc, but captures
-// the singleton ThreadingContext to reduce MatMul call overhead.
+// Mostly stateless, can be constructed on the fly by weights.cc. Captures the
+// ThreadingContext to shorten call sites.
 class MMParallel {
  public:
  static constexpr size_t kMaxPackages = 4;
@@ -251,7 +251,7 @@ class MMStorage {
       :  // Per-worker copies of `partial` would be wasteful. We instead
          // allocate one instance of the maximum matrix extents because threads
          // write at false-sharing-free granularity.
-        partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN),
+        partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN), allocator,
                          MatPadding::kOdd),
        // Same stride independent of the actual C.Cols() so we can pre-bind.
        partial_(partial_storage_.Row(0), kMaxN, partial_storage_.Stride()) {
@@ -259,7 +259,7 @@ class MMStorage {
     // Must be padded, see `DoDecompressA`.
     parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) {
       pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
-          "pkg_A", Extents2D(kMaxM, kMaxK), MatPadding::kOdd));
+          "pkg_A", Extents2D(kMaxM, kMaxK), allocator, MatPadding::kOdd));

       if (allocator.ShouldBind()) {
         const size_t node = parallel.Node(pkg_idx);
@@ -91,14 +91,15 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   const size_t cols = A.Cols();
   const size_t B_rows = B.Rows();
   // Round up for DecompressAndZeroPad.
-  MatStorageT<float> a_batch("a_batch", A.Extents(), MatPadding::kOdd);
-  MatStorageT<float> b_trans_batch("b_trans_batch", B.Extents(),
-                                   MatPadding::kOdd);
-  MatStorageT<float> c_batch("c_batch", Extents2D(A.Rows(), B_rows),
-                             MatPadding::kOdd);
+  MatStorageT<float> a_batch("a_batch", A.Extents(), env.ctx.allocator,
+                             MatPadding::kOdd);
+  MatStorageT<float> b_trans_batch("b_trans_batch", B.Extents(),
+                                   env.ctx.allocator, MatPadding::kOdd);
+  MatStorageT<float> c_batch("c_batch", Extents2D(A.Rows(), B_rows),
+                             env.ctx.allocator, MatPadding::kOdd);
   c_batch.AllocateAndAttachRowPtrs(env.row_ptrs);
   MatStorageT<float> c_slow_batch("c_slow_batch", Extents2D(A.Rows(), B_rows),
-                                  MatPadding::kOdd);
+                                  env.ctx.allocator, MatPadding::kOdd);
   for (size_t m = 0; m < A.Rows(); ++m) {
     DecompressAndZeroPad(df, MakeSpan(A.Row(m), cols), 0, a_batch.Row(m), cols);
     DecompressAndZeroPad(df, MakeSpan(C.Row(m), B_rows), 0, c_batch.Row(m),
@@ -219,17 +220,21 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
   const Extents2D B_extents(cols_bc, cols_a_rows_b);  // already transposed
   const Extents2D C_extents(rows_ac, cols_bc);

-  MatStorageT<TA> A(GenerateMat<TA>(A_extents, MatPadding::kOdd, pool));
+  MatStorageT<TA> A(
+      GenerateMat<TA>(A_extents, env.ctx.allocator, MatPadding::kOdd, pool));
   // Must be packed because we call Span() on it.
-  MatStorageT<TB> BT(
-      GenerateTransposedMat<TB>(B_extents, MatPadding::kPacked, pool));
-  MatStorageT<TC> C_slow("C_slow", C_extents, MatPadding::kOdd);
-  MatStorageT<TC> C("C", C_extents, MatPadding::kOdd);
+  MatStorageT<TB> BT(GenerateTransposedMat<TB>(B_extents, env.ctx.allocator,
+                                               MatPadding::kPacked, pool));
+  MatStorageT<TC> C_slow("C_slow", C_extents, env.ctx.allocator,
+                         MatPadding::kOdd);
+  MatStorageT<TC> C("C", C_extents, env.ctx.allocator, MatPadding::kOdd);
   C.AllocateAndAttachRowPtrs(env.row_ptrs);

   MatStorageT<float> add_storage =
-      add ? GenerateMat<float>(Extents2D(1, cols_bc), MatPadding::kPacked, pool)
-          : MatStorageT<float>("add", Extents2D(), MatPadding::kPacked);
+      add ? GenerateMat<float>(Extents2D(1, cols_bc), env.ctx.allocator,
+                               MatPadding::kPacked, pool)
+          : MatStorageT<float>("add", Extents2D(), env.ctx.allocator,
+                               MatPadding::kPacked);
   add_storage.SetScale(1.0f);
   const float* add_row = add ? add_storage.PackedScale1() : nullptr;
@@ -252,12 +257,11 @@ void TestTiny() {
   if (HWY_TARGET != first_target) return;

   for (size_t max_packages : {1, 2}) {
-    ThreadingContext::ThreadHostileInvalidate();
     ThreadingArgs threading_args;
     threading_args.bind = Tristate::kTrue;
     threading_args.max_packages = max_packages;
-    ThreadingContext::SetArgs(threading_args);
-    MatMulEnv env(ThreadingContext::Get());
+    ThreadingContext ctx(threading_args);
+    MatMulEnv env(ctx);
     NestedPools& pools = env.ctx.pools;

     if constexpr (GEMMA_DISABLE_TOPOLOGY) {
@@ -291,11 +295,10 @@ void TestAllMatMul() {
     return;
   }

-  ThreadingContext::ThreadHostileInvalidate();
   ThreadingArgs threading_args;
   threading_args.bind = Tristate::kTrue;
-  ThreadingContext::SetArgs(threading_args);
-  MatMulEnv env(ThreadingContext::Get());
+  ThreadingContext ctx(threading_args);
+  MatMulEnv env(ctx);
   NestedPools& pools = env.ctx.pools;
   pools.MaybeStartSpinning(threading_args.spin);
@@ -1018,13 +1018,13 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
 // Input has 4096 (64*64) rows, output has 256 (16*16) rows
 // Each output row is the average of a 4x4 block of input rows
 template <typename T>
-MatStorageT<T> AvgPool4x4(MatStorageT<T>& input) {
+MatStorageT<T> AvgPool4x4(MatStorageT<T>& input, const Allocator& allocator) {
   const Extents2D extents = input.Extents();
   // Input validation
   HWY_DASSERT(extents.rows == 4096);  // 64 * 64 = 4096 input rows
   // Create output with 256 rows and same number of columns
   const size_t out_rows = 256;  // 16 * 16 = 256 output rows
-  MatStorageT<T> result("pool4x4", Extents2D(out_rows, extents.cols),
+  MatStorageT<T> result("pool4x4", Extents2D(out_rows, extents.cols), allocator,
                         MatPadding::kOdd);
   const size_t input_dim = 64;   // Input is 64×64
   const size_t output_dim = 16;  // Output is 16×16
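Because `AvgPool4x4` now allocates its result through the caller-supplied allocator, a call site looks roughly like the sketch below (assumed inside namespace gcpp). The 4096-row input follows the assert above; the column count and names are illustrative:

  ThreadingArgs args;
  ThreadingContext ctx(args);
  MatStorageT<float> vit_patches("vit_patches", Extents2D(4096, 1152),
                                 ctx.allocator, MatPadding::kOdd);
  // ... fill vit_patches ...
  MatStorageT<float> pooled = AvgPool4x4(vit_patches, ctx.allocator);  // 256 rows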
@@ -26,9 +26,10 @@
 namespace gcpp {

 static inline HWY_MAYBE_UNUSED MatStorageT<float> CreateInvTimescale(
-    size_t qkv_dim, bool half_rope, double base_frequency = 10000.0) {
+    const Allocator& allocator, size_t qkv_dim, bool half_rope,
+    double base_frequency = 10000.0) {
   const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim;
-  MatStorageT<float> inv_timescale("inv_timescale", rope_dim / 2);
+  MatStorageT<float> inv_timescale("inv_timescale", rope_dim / 2, allocator);
   for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
     const double freq_exponents =
         static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
@@ -347,10 +347,12 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
 }

 void TestRopeAndMulBy() {
+  ThreadingArgs threading_args;
+  ThreadingContext ctx(threading_args);
   const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
                            ChooseWrapping(Model::GEMMA2_9B));
   const size_t dim_qkv = config.layer_configs[0].qkv_dim;
-  MatStorageT<float> x("x", dim_qkv);
+  MatStorageT<float> x("x", dim_qkv, ctx.allocator);

   std::mt19937 gen;
   gen.seed(0x12345678);
@@ -364,13 +366,13 @@ void TestRopeAndMulBy() {
   const float qmul = AttentionActivations::ChooseQueryScale(config);
   constexpr float kmul = 1.0f;

-  MatStorageT<float> qexpected("qexpected", dim_qkv);
-  MatStorageT<float> qactual("qactual", dim_qkv);
-  MatStorageT<float> kexpected("kexpected", dim_qkv);
-  MatStorageT<float> kactual("kactual", dim_qkv);
-  MatStorageT<float> kactual2("kactual2", dim_qkv);
+  MatStorageT<float> qexpected("qexpected", dim_qkv, ctx.allocator);
+  MatStorageT<float> qactual("qactual", dim_qkv, ctx.allocator);
+  MatStorageT<float> kexpected("kexpected", dim_qkv, ctx.allocator);
+  MatStorageT<float> kactual("kactual", dim_qkv, ctx.allocator);
+  MatStorageT<float> kactual2("kactual2", dim_qkv, ctx.allocator);
   MatStorageT<float> inv_timescale = CreateInvTimescale(
-      config.layer_configs[0].qkv_dim,
+      ctx.allocator, config.layer_configs[0].qkv_dim,
       config.layer_configs[0].post_qk == PostQKType::HalfRope);
   // Assert VectorizedRope computation is same as regular rope at different pos.
   for (size_t pos = 1; pos < 500; pos++) {
@@ -21,7 +21,7 @@ void PaliGemmaHelper::InitVit(const std::string& path) {

     image_tokens_ = std::make_unique<ImageTokens>(
         "image", Extents2D(config.vit_config.seq_len, config.model_dim),
-        MatPadding::kPacked);
+        env_->Env().ctx.allocator, MatPadding::kPacked);
     image_tokens_->AllocateAndAttachRowPtrs(env_->Env().row_ptrs);
     Image image;
     HWY_ASSERT(image.ReadPPM(path));
@@ -186,7 +186,7 @@ class GemmaModel {
     image_tokens_.reset(new gcpp::ImageTokens(
         "image_tokens",
         gcpp::Extents2D(config.vit_config.seq_len, config.model_dim),
-        gcpp::MatPadding::kOdd));
+        env_.MutableEnv().ctx.allocator, gcpp::MatPadding::kOdd));
     gcpp::RuntimeConfig runtime_config = {.gen = &env_.MutableGen(),
                                           .verbosity = 0};
     gemma.GenerateImageTokens(runtime_config, env_.MutableKVCache().SeqLen(),
@@ -78,10 +78,10 @@ size_t Stride(MatPadding padding, size_t cols, size_t element_bytes,
   }
 }

-void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) {
+void MatOwner::AllocateFor(MatPtr& mat, const Allocator& allocator,
+                           MatPadding padding) {
   const bool is_nuq = mat.GetType() == Type::kNUQ;
   if (is_nuq) padding = MatPadding::kPacked;
-  const Allocator& allocator = ThreadingContext::Get().allocator;
   const size_t stride =
       Stride(padding, mat.Cols(), mat.ElementBytes(), allocator.LineBytes());
   const size_t num = is_nuq ? mat.PackedBytes() : mat.Rows() * stride;
util/mat.h (37 lines changed)
@@ -443,7 +443,7 @@ class MatOwner {
   // Allocates the type/extents indicated by `mat` and sets its pointer.
   // Ignores `padding` for NUQ tensors, which are always packed.
   // Thread-compatible, weights are allocated in parallel.
-  void AllocateFor(MatPtr& mat, MatPadding padding);
+  void AllocateFor(MatPtr& mat, const Allocator& allocator, MatPadding padding);

  private:
   AlignedPtr<uint8_t[]> storage_;
@@ -455,13 +455,14 @@ class MatOwner {
 template <typename MatT>
 class MatStorageT : public MatPtrT<MatT> {
  public:
-  MatStorageT(const char* name, Extents2D extents, MatPadding padding)
+  MatStorageT(const char* name, Extents2D extents, const Allocator& allocator,
+              MatPadding padding)
       : MatPtrT<MatT>(name, extents) {
-    if (extents.Area() != 0) owner_.AllocateFor(*this, padding);
+    if (extents.Area() != 0) owner_.AllocateFor(*this, allocator, padding);
   }
   // Shorthand for 1D tensors: packing does not help, hence `kPacked`.
-  MatStorageT(const char* name, size_t cols)
-      : MatStorageT(name, Extents2D(1, cols), MatPadding::kPacked) {}
+  MatStorageT(const char* name, size_t cols, const Allocator& allocator)
+      : MatStorageT(name, Extents2D(1, cols), allocator, MatPadding::kPacked) {}
   ~MatStorageT() = default;

   // Allow move for KVCache.
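With these constructors, every `MatStorageT` receives the `Allocator` explicitly. A brief usage sketch assuming a `ThreadingContext` as elsewhere in this commit; tensor names and extents are illustrative:

  ThreadingArgs args;
  ThreadingContext ctx(args);
  // 2D tensor with row padding:
  MatStorageT<float> activations("activations", Extents2D(4, 2048),
                                 ctx.allocator, MatPadding::kOdd);
  // 1D shorthand, always kPacked:
  MatStorageT<float> bias("bias", 2048, ctx.allocator);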
@@ -472,5 +473,31 @@ class MatStorageT : public MatPtrT<MatT> {
   MatOwner owner_;
 };

+// Helper for initializing members which are `MatStorageT<T>`: avoids having to
+// specify Extents2D and MatPadding at each call site.
+class MatFactory {
+ public:
+  // The constructor captures all the necessary arguments.
+  MatFactory(const char* name, size_t rows, size_t cols,
+             const Allocator& allocator, MatPadding padding = MatPadding::kOdd)
+      : name_(name),
+        extents_(rows, cols),
+        allocator_(allocator),
+        padding_(padding) {}
+
+  // Templated conversion so we do not have to specify the type in the
+  // member initializer.
+  template <typename T>
+  operator MatStorageT<T>() const {
+    return MatStorageT<T>(name_.c_str(), extents_, allocator_, padding_);
+  }
+
+ private:
+  const std::string name_;
+  Extents2D extents_;
+  const Allocator& allocator_;
+  MatPadding padding_;
+};
+
 }  // namespace gcpp

 #endif  // THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_
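The new `MatFactory` is aimed at member initializer lists, where the element type is already fixed by the member declaration and the templated conversion operator supplies it. A hypothetical example; the struct and its members are illustrative, not part of this commit:

struct ExampleActivations {
  MatStorageT<float> x;
  MatStorageT<BF16> logits;

  ExampleActivations(size_t batch, size_t model_dim, size_t vocab,
                     const Allocator& allocator)
      : x(MatFactory("x", batch, model_dim, allocator)),
        logits(MatFactory("logits", batch, vocab, allocator)) {}
};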
@@ -15,60 +15,13 @@

 #include "util/threading_context.h"

-#include <memory>
-#include <mutex>  // NOLINT
-
-#include "hwy/base.h"  // HWY_ASSERT, HWY_UNLIKELY
-#include "hwy/profiler.h"
-
 namespace gcpp {

-static ThreadingArgs s_args;
-
-// Cannot use magic static because that does not support `Invalidate`, hence
-// allocate manually.
-static std::unique_ptr<ThreadingContext> s_ctx;
-static std::mutex s_ctx_mutex;
-
-/*static*/ void ThreadingContext::SetArgs(const ThreadingArgs& args) {
-  s_ctx_mutex.lock();
-  HWY_ASSERT(!s_ctx);  // Ensure not already initialized, else this is too late.
-  s_args = args;
-  s_ctx_mutex.unlock();
-}
-
-/*static*/ bool ThreadingContext::IsInitialized() {
-  s_ctx_mutex.lock();
-  const bool initialized = !!s_ctx;
-  s_ctx_mutex.unlock();
-  return initialized;
-}
-
-/*static*/ ThreadingContext& ThreadingContext::Get() {
-  PROFILER_FUNC;
-  // We do not bother with double-checked locking because it requires an
-  // atomic pointer, but we prefer to use unique_ptr for simplicity. Also,
-  // callers can cache the result and call less often.
-  s_ctx_mutex.lock();
-  if (HWY_UNLIKELY(!s_ctx)) {
-    s_ctx = std::make_unique<ThreadingContext>(PrivateToken());
-  }
-  s_ctx_mutex.unlock();
-  return *s_ctx;
-}
-
-/*static*/ void ThreadingContext::ThreadHostileInvalidate() {
-  // Deliberately avoid taking the lock so that tsan can warn if this is
-  // called concurrently with other calls to `Get`.
-  s_ctx.reset();
-}
-
-// WARNING: called with `s_ctx_mutex` held. Calling `SetArgs` or `Get` would
-// deadlock.
-ThreadingContext::ThreadingContext(ThreadingContext::PrivateToken)
-    : topology(BoundedSlice(s_args.skip_packages, s_args.max_packages),
-               BoundedSlice(s_args.skip_clusters, s_args.max_clusters),
-               BoundedSlice(s_args.skip_lps, s_args.max_lps)),
-      allocator(topology, s_args.bind != Tristate::kFalse),
-      pools(topology, allocator, s_args.max_threads, s_args.pin) {}
+ThreadingContext::ThreadingContext(const ThreadingArgs& args)
+    : topology(BoundedSlice(args.skip_packages, args.max_packages),
+               BoundedSlice(args.skip_clusters, args.max_clusters),
+               BoundedSlice(args.skip_lps, args.max_lps)),
+      allocator(topology, args.bind != Tristate::kFalse),
+      pools(topology, allocator, args.max_threads, args.pin) {}

 }  // namespace gcpp
@@ -85,43 +85,9 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
   }
 };

-// Lazily-initialized singleton with support for passing in arguments from
-// `ThreadingArgs` and re-initializing with different arguments.
-class ThreadingContext {
-  struct PrivateToken {};  // avoids constructing directly
-
- public:
-  // If not called, default arguments are used when `Get` initializes the
-  // singleton. Must not be called after `Get`, unless after a call to
-  // `ThreadHostileInvalidate`, because otherwise initialization already
-  // happened and the arguments would have no effect. Thread-safe, though this
-  // is expected to be called early in the program, before threading starts.
-  static void SetArgs(const ThreadingArgs& args);
-
-  // Returns whether `Get()` has already been called, typically used to avoid
-  // calling `SetArgs` after that, because it would assert.
-  static bool IsInitialized();
-
-  // Returns a reference to the singleton after initializing it if necessary.
-  // When initializing, uses the args passed to `SetArgs`, or defaults.
-  //
-  // It is safe to call this concurrently with other `Get`, but not with
-  // `SetArgs`, because that will warn if called after this, nor with
-  // `ThreadHostileInvalidate`, because that will invalidate the reference which
-  // callers of this may still be using. Such usage only occurs in tests,
-  // hence we prefer not to pull `std::shared_ptr` into the interface.
-  //
-  // To reduce overhead, callers should cache the result and call less often.
-  static ThreadingContext& Get();
-
-  // Invalidates the singleton before or after a call to `Get`. This allows
-  // changing the arguments between tests. Callers must again call `Get`
-  // afterwards to obtain an instance. WARNING: must not be called concurrently
-  // with other calls to `Get` and usages of its return value.
-  // Also useful to suppress memory leak warnings in tests.
-  static void ThreadHostileInvalidate();
-
-  explicit ThreadingContext(PrivateToken);  // only called via `Get`.
-
+struct ThreadingContext {
+  // Expected to be called early in the program, before threading starts.
+  explicit ThreadingContext(const ThreadingArgs& args);
+
   BoundedTopology topology;
   Allocator allocator;
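After this header change, `ThreadingContext` is a plain struct initialized once by its constructor; tests that previously relied on `SetArgs`/`ThreadHostileInvalidate` now simply build a fresh context per configuration, as `TestAllDot` and `TestTiny` do above. A sketch of constraining a context, using only fields that appear in this diff:

  ThreadingArgs args;
  args.max_packages = 1;  // one package, as in TestAllDot
  args.max_clusters = 1;
  args.max_lps = 14;      // cap logical processors (kMaxWorkers - 1 above)
  ThreadingContext ctx(args);
  NestedPools& pools = ctx.pools;              // per-package/cluster pools
  const Allocator& allocator = ctx.allocator;  // topology-aware allocation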
||||||
|
|
@ -280,8 +280,6 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
|
||||||
}
|
}
|
||||||
const double t1 = hwy::platform::Now();
|
const double t1 = hwy::platform::Now();
|
||||||
|
|
||||||
// TODO(janwas): enable after Highway update
|
|
||||||
#if 0
|
|
||||||
if (pool.AutoTuneComplete()) {
|
if (pool.AutoTuneComplete()) {
|
||||||
hwy::Span<hwy::CostDistribution> cd = pool.AutoTuneCosts();
|
hwy::Span<hwy::CostDistribution> cd = pool.AutoTuneCosts();
|
||||||
std::vector<double> costs;
|
std::vector<double> costs;
|
||||||
|
|
@@ -308,10 +306,6 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
   } else {
     HWY_WARN("Auto-tuning did not complete yet.");
   }
-#else
-  (void)t0;
-  (void)t1;
-#endif

   char cpu100[100];
   static const bool have_stop = hwy::platform::HaveTimerStop(cpu100);
@@ -383,7 +377,9 @@ TEST(ThreadingTest, BenchJoin) {
     }
   };

-  NestedPools& pools = ThreadingContext::Get().pools;
+  ThreadingArgs threading_args;
+  ThreadingContext ctx(threading_args);
+  NestedPools& pools = ctx.pools;
   // Use last package because the main thread has been pinned to it.
   const size_t pkg_idx = pools.NumPackages() - 1;