Merge branch 'dev' into upgrade-github-actions-node24

Salman Chishti 2025-12-16 14:47:59 +00:00 committed by GitHub
commit a4c78d4454
93 changed files with 4095 additions and 1810 deletions

View File

@ -66,6 +66,7 @@ cc_library(
srcs = ["util/topology.cc"], srcs = ["util/topology.cc"],
hdrs = ["util/topology.h"], hdrs = ["util/topology.h"],
deps = [ deps = [
"@highway//:bit_set",
"@highway//:hwy", "@highway//:hwy",
"@highway//:topology", "@highway//:topology",
], ],
@ -111,6 +112,7 @@ cc_library(
":threading", ":threading",
":topology", ":topology",
":zones", ":zones",
"//io",
"@highway//:hwy", "@highway//:hwy",
"@highway//:hwy_test_util", "@highway//:hwy_test_util",
"@highway//:profiler", "@highway//:profiler",
@ -139,7 +141,7 @@ cc_test(
":kv_cache", ":kv_cache",
":mat", ":mat",
":matmul", ":matmul",
":ops", ":test_util",
":threading_context", ":threading_context",
":weights", ":weights",
"@googletest//:gtest_main", # buildcleaner: keep "@googletest//:gtest_main", # buildcleaner: keep
@ -172,8 +174,11 @@ cc_library(
name = "test_util", name = "test_util",
hdrs = ["util/test_util.h"], hdrs = ["util/test_util.h"],
deps = [ deps = [
":basics",
":mat",
"@highway//:hwy", "@highway//:hwy",
"@highway//:hwy_test_util", "@highway//:hwy_test_util",
"@highway//:nanobenchmark",
"@highway//:stats", "@highway//:stats",
], ],
) )
@ -440,9 +445,9 @@ cc_test(
":gemma_lib", ":gemma_lib",
":mat", ":mat",
":ops", ":ops",
":query",
":test_util", ":test_util",
":threading_context", ":threading_context",
":zones",
"@googletest//:gtest_main", # buildcleaner: keep "@googletest//:gtest_main", # buildcleaner: keep
"//compression:test_util", "//compression:test_util",
"//compression:types", "//compression:types",
@ -519,32 +524,70 @@ cc_library(
],
)
cc_test(
name = "kv_cache_test",
srcs = ["gemma/kv_cache_test.cc"],
deps = [
":configs",
":gemma_args",
":kv_cache",
":threading_context",
"//testing/base/public:gunit_main",
"@highway//:hwy",
],
)
cc_library(
name = "query",
hdrs = ["gemma/query.h"],
deps = [
":basics",
":gemma_args",
":kv_cache",
"@highway//:hwy",
],
)
cc_library(
name = "gemma_args",
hdrs = ["gemma/gemma_args.h"],
deps = [
":args",
":basics",
":configs",
":mat",
":threading_context",
"//io",
"@highway//:hwy",
"@highway//:profiler",
],
)
cc_test(
name = "gemma_args_test",
srcs = ["gemma/gemma_args_test.cc"],
deps = [
":gemma_args",
"@googletest//:gtest_main", # buildcleaner: keep
],
)
cc_library(
name = "gemma_lib",
srcs = [
"gemma/attention.cc",
"gemma/flash_attention.cc",
"gemma/gemma.cc",
"gemma/tensor_stats.cc",
"gemma/vit.cc",
],
hdrs = [
"gemma/activations.h",
"gemma/attention.h",
"gemma/flash_attention.h",
"gemma/flash_structs.h",
"gemma/gemma.h",
"gemma/tensor_stats.h",
"gemma/vit.h",
],
exec_properties = {
@ -555,6 +598,7 @@ cc_library(
"gemma/gemma-inl.h", "gemma/gemma-inl.h",
], ],
deps = [ deps = [
":allocator",
":basics", ":basics",
":configs", ":configs",
":gemma_args", ":gemma_args",
@ -564,6 +608,7 @@ cc_library(
":matmul_env", ":matmul_env",
":model_store", ":model_store",
":ops", ":ops",
":query",
":threading", ":threading",
":threading_context", ":threading_context",
":weights", ":weights",
@ -577,8 +622,34 @@ cc_library(
"@highway//:hwy", "@highway//:hwy",
"@highway//:nanobenchmark", # timer "@highway//:nanobenchmark", # timer
"@highway//:profiler", "@highway//:profiler",
"@highway//:stats",
"@highway//:thread_pool", "@highway//:thread_pool",
"@highway//hwy/contrib/sort:vqsort", "@highway//hwy/contrib/sort:vqsort",
] +
[
],
)
cc_test(
name = "gemma_lib_test",
srcs = ["gemma/attention_test.cc"],
# MatMulEnvs are up to 20GB large.
tags = ["requires-mem:28g"],
deps = [
":configs",
":gemma_args",
":gemma_lib",
":kv_cache",
":mat",
":matmul_env",
":test_util",
":threading_context",
":weights",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"//compression:types",
"@highway//:hwy",
"@highway//:hwy_test_util",
],
)
@ -604,7 +675,6 @@ cc_library(
":gemma_args", ":gemma_args",
":gemma_lib", ":gemma_lib",
":matmul_env", ":matmul_env",
":ops",
":threading_context", ":threading_context",
":tokenizer", ":tokenizer",
"@google_benchmark//:benchmark", "@google_benchmark//:benchmark",

View File

@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579 EXCLUDE_FROM_ALL)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 3b680cde3a556bead9cc23c8f595d07a44d5a0d5 EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(highway)
## Note: absl needs to be installed by sentencepiece. This will only happen if
@ -82,6 +82,7 @@ set(SOURCES
gemma/configs.h
gemma/flash_attention.cc
gemma/flash_attention.h
gemma/flash_structs.h
gemma/gemma_args.h
gemma/gemma-inl.h
gemma/gemma.cc
@ -221,6 +222,7 @@ set(GEMMA_TEST_FILES
compression/nuq_test.cc
compression/sfp_test.cc
evals/gemma_test.cc
gemma/gemma_args_test.cc
gemma/flash_attention_test.cc
gemma/tensor_info_test.cc
io/blob_store_test.cc

View File

@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
# Require a more recent version.
git_override(
module_name = "highway",
commit = "2a16a50ff61071bb25ddef0ce35d92b0e2b9c579",
commit = "3b680cde3a556bead9cc23c8f595d07a44d5a0d5",
remote = "https://github.com/google/highway",
)

View File

@ -55,7 +55,6 @@ Guidelines](https://opensource.google.com/conduct/).
- CPU-only inference for: Gemma 2-3, PaliGemma 2.
- Sampling with TopK and temperature.
- Backward pass (VJP) and Adam optimizer for Gemma research.
- Optimizations
@ -452,7 +451,7 @@ FetchContent_MakeAvailable(sentencepiece)
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
FetchContent_MakeAvailable(gemma)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 3b680cde3a556bead9cc23c8f595d07a44d5a0d5)
FetchContent_MakeAvailable(highway)
```
@ -520,13 +519,19 @@ Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode
Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas
Fischbacher and Zoltan Szabadka. It was removed in 2025-09.
Gemma-2 support was implemented in June/July 2024 with the help of several
people.
Gemma 2 support was implemented in June/July 2024 with the help of several
people including Daniel Keysers and Phil Culliton.
PaliGemma support was implemented in September 2024 with contributions from
Daniel Keysers.
Gemma 3 support was implemented in January-March 2025 with contributions from
Daniel Keysers and Phil Culliton.
[Jan Wassenberg](mailto:janwas@google.com) has continued to contribute many
improvements, including major gains in efficiency, since the initial release.
[Phil Culliton](mailto:philculliton@google.com) has worked on model releases,
eval and validation, GTM, and quantization, since the initial release.
This is not an officially supported Google product.

View File

@ -1,4 +1,4 @@
# Weight compression and analysis.
# Compressed tensor types.
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")
@ -101,13 +101,11 @@ cc_test(
# for test_suite.
tags = ["hwy_ops_test"],
deps = [
":distortion",
":int",
"@googletest//:gtest_main", # buildcleaner: keep
"//:test_util",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
],
)
@ -135,8 +133,8 @@ cc_test(
# for test_suite.
tags = ["hwy_ops_test"],
deps = [
":compress",
":distortion",
":sfp",
"@googletest//:gtest_main", # buildcleaner: keep
"//:test_util",
"@highway//:hwy",
@ -182,7 +180,6 @@ cc_library(
"//:mat", "//:mat",
"//:threading_context", "//:threading_context",
"@highway//:hwy", "@highway//:hwy",
"@highway//:nanobenchmark",
"@highway//:profiler", "@highway//:profiler",
"@highway//:stats", "@highway//:stats",
"@highway//:thread_pool", "@highway//:thread_pool",
@ -209,19 +206,3 @@ cc_test(
"@highway//:hwy_test_util", "@highway//:hwy_test_util",
], ],
) )
# For internal experimentation
cc_library(
name = "analyze",
textual_hdrs = ["analyze.h"],
deps = [
":int",
":nuq",
":sfp",
":types",
"@highway//:hwy",
"@highway//:stats",
"@highway//:thread_pool",
"@highway//hwy/contrib/sort:vqsort",
],
)

View File

@ -1,238 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Normal include guard to placate lint.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h> // memcpy
#include <cmath> // std::signbit
#include <cstdlib> // std::abs
#include <vector>
#include "compression/types.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/stats.h"
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_
// Actual per-target include guard.
#if defined(THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
#endif
#include "compression/nuq-inl.h"
#include "compression/sfp-inl.h"
#include "hwy/contrib/sort/vqsort-inl.h"
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
class PerThread {
public:
void NotifyGroup(const float* group) {
constexpr size_t kGroupSize = NuqStream::kGroupSize;
hwy::Stats s_group;
for (size_t i = 0; i < kGroupSize; ++i) {
// Skip zero so we can see the lowest actual magnitude
if (group[i] == 0.0f || group[i] == -0.0f) continue;
s_all_.Notify(group[i]);
s_group.Notify(group[i]);
num_tiny_ += std::abs(group[i]) < 1e-3f;
// b_magn100_.Notify(group[i] * 40.0f + 20.0f);
const uint32_t binary32 =
hwy::BitCastScalar<uint32_t>(std::abs(group[i]));
// const int32_t exp = (binary32 >> 23) - 127;
b_exp256_.Notify(binary32 >> 23);
const uint32_t m4 = (binary32 & 0x7FFFFF) >> (23 - 4);
b_m4_.Notify(m4);
}
s_group_ranges_.Notify(s_group.Max() - s_group.Min());
s_group_mins_.Notify(s_group.Min());
s_group_maxs_.Notify(s_group.Max());
float desc[kGroupSize];
memcpy(desc, group, kGroupSize * sizeof(group[0]));
hn::VQSortStatic(desc, kGroupSize, hwy::SortDescending());
// Find largest |max/min| (dynamic range)
float max_ratio = 0.0f;
for (size_t i = 0; i < kGroupSize; ++i) {
if (desc[i] != 0.0f && desc[i] != -0.0f) {
max_ratio = std::max(max_ratio, std::abs(desc[0] / desc[i]));
}
}
s_group_max_vs_min_.Notify(max_ratio);
// Relative errors
float diffs[kGroupSize];
for (size_t i = 0; i < kGroupSize - 1; ++i) {
// was in descending order. Avoid div by 0. Ignore sign changes.
diffs[i] = std::abs(desc[i]) < 1e-5
? 0
: std::abs((desc[i] - desc[i + 1]) / desc[i]);
}
hn::VQSortStatic(diffs, kGroupSize, hwy::SortDescending());
s_cut15_.Notify(diffs[15]);
}
void Assimilate(const PerThread& other) {
num_tiny_ += other.num_tiny_;
s_all_.Assimilate(other.s_all_);
s_group_ranges_.Assimilate(other.s_group_ranges_);
s_group_mins_.Assimilate(other.s_group_mins_);
s_group_maxs_.Assimilate(other.s_group_maxs_);
s_group_max_vs_min_.Assimilate(other.s_group_max_vs_min_);
s_erange_.Assimilate(other.s_erange_);
s_km_1_.Assimilate(other.s_km_1_);
s_km_2_.Assimilate(other.s_km_2_);
s_cut15_.Assimilate(other.s_cut15_);
b_magn100_.Assimilate(other.b_magn100_);
b_exp256_.Assimilate(other.b_exp256_);
b_m4_.Assimilate(other.b_m4_);
}
void PrintAll() {
const int skip = hwy::Stats::kNoGeomean;
fprintf(stderr, "num tiny %zu\n", num_tiny_);
fprintf(stderr, "weights %s\n", s_all_.ToString(skip).c_str());
fprintf(stderr, " ranges %s\n", s_group_ranges_.ToString(skip).c_str());
fprintf(stderr, " mins %s\n", s_group_mins_.ToString(skip).c_str());
fprintf(stderr, " maxs %s\n", s_group_maxs_.ToString(skip).c_str());
fprintf(stderr, " Mvm %s\n", s_group_max_vs_min_.ToString(skip).c_str());
fprintf(stderr, " cut15 %s\n", s_cut15_.ToString(skip).c_str());
fprintf(stderr, " erange %s\n", s_erange_.ToString(skip).c_str());
fprintf(stderr, " km1 %s\n", s_km_1_.ToString(skip).c_str());
fprintf(stderr, " km2 %s\n", s_km_2_.ToString(skip).c_str());
// b_magn100_.Print("magn100");
// b_exp256_.Print("exp");
// b_m4_.Print("mantissa bits4");
fprintf(stderr, "\n");
}
private:
size_t num_tiny_ = 0;
hwy::Stats s_all_;
hwy::Stats s_group_ranges_;
hwy::Stats s_group_mins_;
hwy::Stats s_group_maxs_;
hwy::Stats s_group_max_vs_min_;
hwy::Stats s_erange_;
hwy::Stats s_km_1_;
hwy::Stats s_km_2_;
hwy::Stats s_cut15_;
hwy::Bins<100> b_magn100_;
hwy::Bins<256> b_exp256_;
hwy::Bins<16> b_m4_;
uint8_t padding_[64]; // prevent false sharing
};
class PerLayer {
public:
void NotifyGroup(const float* group) {
for (size_t i = 0; i < NuqStream::kGroupSize; ++i) {
s_layer_.Notify(group[i]);
}
}
void UpdateOutliers(const float* layer, size_t weights_per_layer) {
const float layer_mean = s_layer_.Mean();
const float layer_sd = s_layer_.StandardDeviation();
for (size_t i = 0; i < weights_per_layer; ++i) {
num_outliers_ +=
std::abs(std::abs(layer[i]) - layer_mean) >= 3.0f * layer_sd;
}
}
const hwy::Stats& GetStats() const { return s_layer_; }
size_t Outliers() const { return num_outliers_; }
private:
hwy::Stats s_layer_;
size_t num_outliers_ = 0;
uint8_t padding[64]; // prevent false sharing
};
static HWY_NOINLINE void Analyze(const char* caption, float* mat, size_t layers,
size_t weights_per_layer,
hwy::ThreadPool& pool) {
std::vector<PerThread> tls;
std::vector<PerLayer> per_layer(layers);
const auto init = [&](size_t num_threads) {
tls.resize(num_threads);
return true;
};
pool.Run(0, static_cast<uint32_t>(layers), init,
[&](uint32_t idx_layer, size_t idx_thread) {
PerThread& self = tls[idx_thread];
const float* layer = &mat[idx_layer * weights_per_layer];
// For each whole group in the layer
for (size_t group_start = 0;
group_start + NuqStream::kGroupSize <= weights_per_layer;
group_start += NuqStream::kGroupSize) {
const float* group = layer + group_start;
per_layer[idx_layer].NotifyGroup(group);
self.NotifyGroup(group);
}
per_layer[idx_layer].UpdateOutliers(layer, weights_per_layer);
});
const int skip = hwy::Stats::kNoGeomean;
fprintf(stderr, "\n------------%s\n", caption);
for (size_t i = 1; i < pool.NumWorkers(); ++i) {
tls[0].Assimilate(tls[i]);
}
tls[0].PrintAll();
hwy::Stats s_layer_ranges;
hwy::Stats s_layer_outliers;
for (size_t i = 0; i < layers; ++i) {
fprintf(stderr, " %02zu %s\n", i,
per_layer[i].GetStats().ToString(skip).c_str());
const float range =
per_layer[i].GetStats().Max() - per_layer[i].GetStats().Min();
s_layer_ranges.Notify(range);
s_layer_outliers.Notify((100.0 * per_layer[i].Outliers()) /
weights_per_layer);
}
fprintf(stderr, "layer outliers%% %s\n",
s_layer_outliers.ToString(skip).c_str());
fprintf(stderr, "layer ranges %s\n", s_layer_ranges.ToString(skip).c_str());
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_

View File

@ -82,6 +82,8 @@ struct CompressTraits<float> {
hn::StoreU(raw1, df, packed.ptr + packed_ofs + NF);
}
static float ToFloatSlow(const Packed x) { return x; }
template <class DBF16, HWY_IF_BF16_D(DBF16), class VBF16 = hn::Vec<DBF16>>
static HWY_INLINE void Load2(DBF16 dbf16,
const PackedSpan<const Packed>& packed,
@ -254,6 +256,10 @@ struct CompressTraits<BF16> {
packed.ptr + packed_ofs);
}
static float ToFloatSlow(const Packed x) {
return hwy::ConvertScalarTo<float>(x);
}
template <class DBF16, HWY_IF_BF16_D(DBF16)>
static HWY_INLINE void Load2(DBF16 dbf16,
const PackedSpan<const Packed>& packed,
@ -397,6 +403,27 @@ struct CompressTraits<SfpStream> {
}
}
// NOTE: this does not take into account the per-tensor scale.
static float ToFloatSlow(const Packed x) {
uint32_t sfp = x.byte;
HWY_ASSERT(sfp != 0x80); // -0 is reserved
const uint32_t sign32 = (sfp & 0x80) << 24;
sfp &= 0x7F;
const bool large_e = sfp >= 64;
const size_t m_bits = large_e ? 3 : 2;
uint32_t m = sfp & ((1u << m_bits) - 1u);
size_t e = sfp >> m_bits;
if (sfp == 0) return 0.0f;
const uint32_t e_bias = large_e ? 15 : 23;
const uint32_t exp32 = static_cast<uint32_t>(127 + e - e_bias) << 23;
const uint32_t mnt32 = m << (23 - m_bits);
const uint32_t binary32 = sign32 | exp32 | mnt32;
float result;
hwy::CopySameSize(&binary32, &result);
return result;
}
template <class D> // Caller checks this is f32 or bf16
static HWY_INLINE void Load2(D d, const PackedSpan<const Packed>& packed,
const size_t packed_ofs, hn::Vec<D>& raw0,
@ -437,6 +464,12 @@ struct CompressTraits<I8Stream> {
IntCodec::Dec2(d, packed, packed_ofs, raw0, raw1);
}
static float ToFloatSlow(const Packed x) {
HWY_DASSERT(!"Not supported - requires a stream");
return 0.0f;
}
// Store2 is not yet implemented.
template <class D, typename Raw>
static HWY_INLINE void DecompressAndZeroPad(
D d, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
@ -483,6 +516,10 @@ struct CompressTraits<NuqStream> {
NuqCodec::Dec2(d, packed, packed_ofs, raw0, raw1);
}
static float ToFloatSlow(const Packed x) {
HWY_DASSERT(!"Not supported - requires a stream");
return 0.0f;
}
// Store2 is not yet implemented.
template <class D, typename Raw>
@ -604,6 +641,13 @@ HWY_INLINE void DecompressAndZeroPad(DRaw d, const PackedSpan<Packed>& packed,
Traits::DecompressAndZeroPad(d, MakeConst(packed), packed_ofs, raw, num);
}
// NOTE: the following are the recommended way to iterate over arrays of
// potentially compressed elements, including remainder handling. Prefer them
// over calling `Decompress2` directly, which does not handle remainders.
// `DecompressAndCall` is for algorithms expressed as `Kernel` objects, such as
// `Dot`. `Decompress*AndCompress*` are for varying numbers of input arrays and
// user code expressed as lambdas.
// Invokes `kernel` for the `v.num` elements of `w` and `v`. Decompresses from
// both into groups of four vectors with lane type `Kernel::Raw`, passes them to
// `kernel.Update4`; loads the final vector(s) with zero-padding, then passes
@ -733,8 +777,8 @@ HWY_INLINE float DecompressAndCall(D, const PackedSpan<const VT> v,
comp3);
}
// Similar to `hn::Transform*`, but for compressed `T`. Used by ops-inl.h.
// `DF` is the decompressed type, typically `float`.
// Similar to `hn::Transform*`, but for compressed `T`. Used by `ops-inl.h`.
// `DF` is the decompressed type, typically `float`. Calls `func(df, v_inout)`.
template <class DF, typename T, class Func>
HWY_INLINE void DecompressAndCompressInplace(DF df, T* HWY_RESTRICT inout,
size_t num, Func&& func) {
@ -773,6 +817,7 @@ HWY_INLINE void DecompressAndCompressInplace(DF df, T* HWY_RESTRICT inout,
}
// One extra argument. `DF` is the decompressed type, typically `float`.
// Calls `func(df, v_inout, v1)`.
template <class DF, typename T, typename T1, class Func>
HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout,
size_t num,
@ -821,8 +866,64 @@ HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout,
}
}
// Two extra arguments. `DF` is the decompressed type, typically `float`.
// Calls `func(df, v_inout, v1, v2)`.
template <class DF, typename T, typename T1, typename T2, class Func>
HWY_INLINE void Decompress2AndCompressInplace(
DF df, T* HWY_RESTRICT inout, size_t num, const T1* HWY_RESTRICT p1,
const T2* HWY_RESTRICT p2, const size_t p2_ofs, Func&& func) {
const auto packed_inout = MakeSpan(inout, num);
const auto packed1 = MakeSpan(p1, num);
const auto packed2 = MakeSpan(p2, p2_ofs + num);
using VF = hn::Vec<decltype(df)>;
HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df);
size_t i = 0;
if (num >= 2 * NF) {
for (; i <= num - 2 * NF; i += 2 * NF) {
VF v0, v1;
Decompress2(df, packed_inout, i, v0, v1);
VF v10, v11;
Decompress2(df, packed1, i, v10, v11);
VF v20, v21;
Decompress2(df, packed2, p2_ofs + i, v20, v21);
const VF out0 = func(df, v0, v10, v20);
const VF out1 = func(df, v1, v11, v21);
Compress2(df, out0, out1, packed_inout, i);
}
}
const size_t remaining = num - i;
HWY_DASSERT(remaining < 2 * NF);
if (HWY_UNLIKELY(remaining != 0)) {
HWY_ALIGN float buf_inout[2 * hn::MaxLanes(df)];
HWY_ALIGN float buf1[2 * hn::MaxLanes(df)];
HWY_ALIGN float buf2[2 * hn::MaxLanes(df)];
// Ensure the second vector is zeroed even if remaining <= NF.
hn::Store(hn::Zero(df), df, buf_inout + NF);
hn::Store(hn::Zero(df), df, buf1 + NF);
hn::Store(hn::Zero(df), df, buf2 + NF);
DecompressAndZeroPad(df, packed_inout, i, buf_inout, remaining);
DecompressAndZeroPad(df, packed1, i, buf1, remaining);
DecompressAndZeroPad(df, packed2, p2_ofs + i, buf2, remaining);
const VF v0 = hn::Load(df, buf_inout);
const VF v1 = hn::Load(df, buf_inout + NF);
const VF v10 = hn::Load(df, buf1);
const VF v11 = hn::Load(df, buf1 + NF);
const VF v20 = hn::Load(df, buf2);
const VF v21 = hn::Load(df, buf2 + NF);
const VF out0 = func(df, v0, v10, v20);
const VF out1 = func(df, v1, v11, v21);
Compress2(df, out0, out1, MakeSpan(buf_inout, 2 * NF), 0);
// Clang generates incorrect code for CopyBytes if num = 2.
for (size_t j = 0; j < remaining; ++j) {
inout[i + j] = hwy::ConvertScalarTo<T>(buf_inout[j]);
}
}
}
// Single input, separate output. `DF` is the decompressed type, typically
// `float`.
// `float`. Calls `func(df, v1)`.
template <class DF, typename T, typename T1, class Func>
HWY_INLINE void Decompress1AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
const T1* HWY_RESTRICT p1,
@ -863,7 +964,8 @@ HWY_INLINE void Decompress1AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
}
}
// Two inputs. `DF` is the decompressed type, typically `float`.
// Two inputs, separate output. `DF` is the decompressed type, typically
// `float`. Calls `func(df, v1, v2)`.
template <class DF, typename T, typename T1, typename T2, class Func>
HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
const T1* HWY_RESTRICT p1,
@ -912,7 +1014,8 @@ HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
}
}
// Three inputs. `DF` is the decompressed type, typically `float`.
// Three inputs, separate output. `DF` is the decompressed type, typically
// `float`. Calls `func(df, v1, v2, v3)`.
template <class DF, typename T, typename T1, typename T2, typename T3,
class Func>
HWY_INLINE void Decompress3AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,

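To illustrate the "recommended way to iterate" comment above, here is a minimal sketch (not part of the diff) of how the new `Decompress2AndCompressInplace` might be called, mirroring the lambda shape used in compress_test.cc below. `FusedMulAdd` is a hypothetical name, and the usual gemma.cpp per-target (HWY_NAMESPACE / foreach_target) scaffolding is assumed to surround it.

```
#include "compression/compress-inl.h"
#include "hwy/highway.h"

namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

// inout[i] = inout[i] + w1[i] * w2[i]; T, T1, T2 may be compressed types.
template <typename T, typename T1, typename T2>
void FusedMulAdd(T* HWY_RESTRICT inout, const T1* HWY_RESTRICT w1,
                 const T2* HWY_RESTRICT w2, size_t num) {
  using DF = hn::ScalableTag<float>;  // DF: decompressed type is float
  const DF df;
  using VF = hn::Vec<DF>;
  // Remainder handling (num not a multiple of the vector count) is done
  // inside the helper, as described in the comment block above.
  Decompress2AndCompressInplace(
      df, inout, num, w1, w2, /*p2_ofs=*/0,
      [](DF, VF v_inout, VF v1, VF v2) HWY_ATTR -> VF {
        return hn::MulAdd(v1, v2, v_inout);
      });
}

}  // namespace HWY_NAMESPACE
}  // namespace gcpp
```
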
View File

@ -259,6 +259,13 @@ class TestDecompressAndCompress {
[](DF, VF v, VF v1) HWY_ATTR -> VF { return hn::Add(v, v1); });
HWY_ASSERT_ARRAY_EQ(expected2.get(), out.get(), num);
// `out` already contains v + v1.
Decompress2AndCompressInplace(
df, out.get(), num, p1.get(), p2.get(), /*p2_ofs=*/0,
[](DF, VF v, VF /*v1*/, VF v2)
HWY_ATTR -> VF { return hn::Add(v, v2); });
HWY_ASSERT_ARRAY_EQ(expected3.get(), out.get(), num);
Decompress1AndCompressTo(df, out.get(), num, p.get(),
[](DF, VF v) HWY_ATTR -> VF { return v; });
HWY_ASSERT_ARRAY_EQ(expected1.get(), out.get(), num);

View File

@ -480,9 +480,12 @@ class NibbleCodec {
static_assert(kHalf <= 1);
const size_t N = hn::Lanes(d8);
constexpr size_t kMaxN = hn::MaxLanes(d8);
constexpr bool kPermuteAcrossBlocks =
HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86;
// For kHalf=1 and 512-bit vectors, kAdd would be 16, which is out of
// bounds for TableLookupBytes. We instead BroadcastBlock<1> there.
constexpr uint8_t kAdd = kMaxN < 64 ? kHalf * kMaxN / 4 : 0;
constexpr uint8_t kAdd =
kMaxN < 64 || kPermuteAcrossBlocks ? kHalf * kMaxN / 4 : 0;
// The only performance-portable op to replicate bytes is TableLookupBytes,
// but this only works if vectors are 128-bit or we first BroadcastBlock,
// which only works for <= 512-bit vectors. For scalable vectors, we
@ -506,7 +509,7 @@ class NibbleCodec {
} else if constexpr (kMaxN <= 16) { // <= 128-bit
// No BroadcastBlock, we anyway only have one block.
return hn::TableLookupBytes(bytes, hn::Load(d8, kRep4));
} else if constexpr (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) {
} else if constexpr (kPermuteAcrossBlocks) {
// No BroadcastBlock, can directly permute across blocks.
return hn::TableLookupLanes(bytes, hn::SetTableIndices(d8, kRep4));
} else { // 256..512-bit, no efficient TableLookupLanes

View File

@ -26,7 +26,6 @@ cc_library(
"//io", "//io",
"//io:blob_store", "//io:blob_store",
"@highway//:hwy", "@highway//:hwy",
"@highway//:thread_pool",
], ],
) )

View File

@ -37,37 +37,23 @@
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h" #include "hwy/highway.h"
// After highway.h // After highway.h
#include "compression/sfp-inl.h" #include "compression/compress-inl.h"
#include "hwy/tests/test_util-inl.h" #include "hwy/tests/test_util-inl.h"
HWY_BEFORE_NAMESPACE(); HWY_BEFORE_NAMESPACE();
namespace gcpp { namespace gcpp {
namespace HWY_NAMESPACE { namespace HWY_NAMESPACE {
// Decode
float F32FromSFP8(uint32_t sfp) {
HWY_ASSERT(sfp < 256);
HWY_ASSERT(sfp != 0x80); // -0 is reserved
const uint32_t sign32 = (sfp & 0x80) << 24;
sfp &= 0x7F;
const bool large_e = sfp >= 64;
const size_t m_bits = large_e ? 3 : 2;
uint32_t m = sfp & ((1u << m_bits) - 1u);
size_t e = sfp >> m_bits;
if (sfp == 0) return 0.0f;
const uint32_t e_bias = large_e ? 15 : 23;
const uint32_t exp32 = static_cast<uint32_t>(127 + e - e_bias) << 23;
const uint32_t mnt32 = m << (23 - m_bits);
const uint32_t binary32 = sign32 | exp32 | mnt32;
float result;
hwy::CopySameSize(&binary32, &result);
return result;
HWY_INLINE_VAR constexpr bool kPrint = false;
static float F32FromSFP8(uint32_t sfp) {
return CompressTraits<SfpStream>::ToFloatSlow(
SfpStream{static_cast<uint8_t>(sfp)});
}
// Used for HWY_AVX3_DL and newer.
void PrintTables() {
if (HWY_ONCE && false) {
if (HWY_ONCE && kPrint) {
uint8_t hi[128];
fprintf(stderr, "lo\n");
for (uint32_t sfp = 0; sfp < 128; ++sfp) {
@ -92,7 +78,7 @@ void TestAllUnique() {
unique.insert(F32FromSFP8(sfp));
}
HWY_ASSERT_EQ(size_t{255}, unique.size());
if (false) {
if (kPrint) {
for (float f : unique) {
fprintf(stderr, "%e\n", f);
}
@ -163,7 +149,7 @@ HWY_INLINE uint32_t SFP8FromF32(float f) {
if (m == 0) m = 1;
}
if (false) {
if (kPrint) {
fprintf(stderr, "in %x round %x rounded %x e %d m %x large_e %d\n",
org_binary32, round, rounded, e, m, large_e);
}
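
For orientation, a small standalone sketch (not from the diff) of the SFP8 decoding rule that `ToFloatSlow`/`F32FromSFP8` implement, ignoring the per-tensor scale: codes 1..63 carry 2 mantissa bits with exponent bias 23, codes 64..127 carry 3 mantissa bits with bias 15, bit 7 is the sign, and 0x80 is reserved. The helper name `DecodeSfp` and the worked values are illustrative only.

```
#include <cassert>
#include <cmath>
#include <cstdint>

// Decodes one SFP8 byte the same way as ToFloatSlow above (per-tensor scale
// and the reserved 0x80 code are ignored here).
static float DecodeSfp(uint8_t byte) {
  const float sign = (byte & 0x80) ? -1.0f : 1.0f;
  const uint32_t sfp = byte & 0x7F;
  if (sfp == 0) return 0.0f;
  const bool large_e = sfp >= 64;  // upper half of the code space
  const int m_bits = large_e ? 3 : 2;
  const int m = static_cast<int>(sfp) & ((1 << m_bits) - 1);
  const int e = static_cast<int>(sfp >> m_bits) - (large_e ? 15 : 23);
  return sign * std::ldexp(1.0f + m / static_cast<float>(1 << m_bits), e);
}

int main() {
  assert(DecodeSfp(0x7F) == 1.875f);       // largest magnitude: 1.875 * 2^0
  assert(DecodeSfp(0x40) == 0.0078125f);   // smallest "large_e" code: 2^-7
  assert(DecodeSfp(0x3F) == 1.75f / 256);  // largest "small_e" code: 1.75 * 2^-8
  return 0;
}
```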

View File

@ -105,7 +105,7 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents, MatPadding padding,
MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked); MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked);
MatStorageT<MatT> compressed("mat", extents, ctx.allocator, padding); MatStorageT<MatT> compressed("mat", extents, ctx.allocator, padding);
const float scale = SfpStream::kMax / extents.Area(); const float scale = SfpStream::kMax / extents.Area();
ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0, ParallelFor(Parallelism::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
Callers::kTest, [&](size_t r, size_t thread) { Callers::kTest, [&](size_t r, size_t thread) {
float* HWY_RESTRICT row = raw.Row(r); float* HWY_RESTRICT row = raw.Row(r);
for (size_t c = 0; c < extents.cols; c++) { for (size_t c = 0; c < extents.cols; c++) {
@ -134,7 +134,7 @@ MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked); MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked);
MatStorageT<MatT> compressed("trans", extents, ctx.allocator, padding); MatStorageT<MatT> compressed("trans", extents, ctx.allocator, padding);
const float scale = SfpStream::kMax / extents.Area(); const float scale = SfpStream::kMax / extents.Area();
ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0, ParallelFor(Parallelism::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
Callers::kTest, [&](size_t r, size_t thread) { Callers::kTest, [&](size_t r, size_t thread) {
float* HWY_RESTRICT row = raw.Row(r); float* HWY_RESTRICT row = raw.Row(r);
for (size_t c = 0; c < extents.cols; c++) { for (size_t c = 0; c < extents.cols; c++) {

View File

@ -23,7 +23,9 @@ using json = nlohmann::json;
class BenchmarkArgs : public ArgsBase<BenchmarkArgs> {
public:
BenchmarkArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
BenchmarkArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
Path summarize_text;
Path cross_entropy;
@ -127,9 +129,16 @@ int BenchmarkTriviaQA(GemmaEnv& env, const Path& json_file,
} // namespace gcpp
int main(int argc, char** argv) {
gcpp::GemmaEnv env(argc, argv);
gcpp::BenchmarkArgs benchmark_args(argc, argv);
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
gcpp::BenchmarkArgs benchmark_args(argc, argv, consumed);
if (gcpp::HasHelp(argc, argv)) {
args.Help();
return 0;
}
consumed.AbortIfUnconsumed();
gcpp::GemmaEnv env(args);
if (!benchmark_args.summarize_text.Empty()) {
return BenchmarkSummary(env, benchmark_args.summarize_text);
} else if (!benchmark_args.cross_entropy.Empty()) {

View File

@ -36,30 +36,29 @@
namespace gcpp {
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference)
: ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) {
GemmaEnv::GemmaEnv(const GemmaArgs& args)
: initializer_value_(gcpp::InternalInit()),
ctx_(args.threading),
env_(ctx_),
gemma_(args, ctx_) {
const ModelConfig& config = gemma_.Config();
// Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));
if (inference.verbosity >= 2) {
ShowConfig(loader, threading, inference, config, gemma_.WeightReadMode(),
ctx_);
kv_caches_.push_back(KVCache(config, args.inference, ctx_.allocator));
if (args.inference.verbosity >= 2) {
ShowConfig(args, config, gemma_.WeightReadMode(), ctx_);
}
if (args.inference.verbosity >= 3) env_.print_best = true;
if (args.inference.verbosity >= 4) env_.print_config = true;
runtime_config_ = {
.max_generated_tokens = inference.max_generated_tokens,
.temperature = inference.temperature,
.verbosity = inference.verbosity,
.max_generated_tokens = args.inference.max_generated_tokens,
.temperature = args.inference.temperature,
.verbosity = args.inference.verbosity,
};
inference.CopyTo(runtime_config_);
args.inference.CopyTo(runtime_config_);
}
GemmaEnv::GemmaEnv(int argc, char** argv)
: GemmaEnv(LoaderArgs(argc, argv), ThreadingArgs(argc, argv),
InferenceArgs(argc, argv)) {}
QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
QueryResult result;
@ -229,19 +228,19 @@ static constexpr const char* CompiledConfig() {
}
}
void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference, const ModelConfig& config,
void ShowConfig(const GemmaArgs& args, const ModelConfig& config,
const WeightsPtrs::Mode weight_read_mode,
const ThreadingContext& ctx) {
threading.Print(inference.verbosity);
loader.Print(inference.verbosity);
inference.Print(inference.verbosity);
fprintf(
stderr, "Model : %s, to_bf16 %d, mmap %d => %s\n",
config.Specifier().c_str(), static_cast<int>(loader.to_bf16),
static_cast<int>(loader.map), WeightsPtrs::ToString(weight_read_mode));
args.threading.Print(args.inference.verbosity);
args.loader.Print(args.inference.verbosity);
args.inference.Print(args.inference.verbosity);
fprintf(stderr,
"Model : %s, to_bf16 %d, mmap %d => %s\n",
config.Specifier().c_str(), static_cast<int>(args.loader.to_bf16),
static_cast<int>(args.loader.map),
WeightsPtrs::ToString(weight_read_mode));
if (inference.verbosity >= 2) {
if (args.inference.verbosity >= 2) {
time_t now = time(nullptr);
char* dt = ctime(&now); // NOLINT
char cpu100[100] = "unknown";
@ -254,7 +253,7 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
"Instruction set : %s (%zu bits)\n" "Instruction set : %s (%zu bits)\n"
"Compiled config : %s, profiler %d\n" "Compiled config : %s, profiler %d\n"
"Memory MiB : %4zu\n", "Memory MiB : %4zu\n",
dt, cpu100, static_cast<int>(threading.bind), dt, cpu100, static_cast<int>(args.threading.bind),
ctx.topology.TopologyString(), ctx.pools.PinString(), ctx.topology.TopologyString(), ctx.pools.PinString(),
CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()), CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()),
ctx.cache_info.VectorBytes() * 8, CompiledConfig(), ctx.cache_info.VectorBytes() * 8, CompiledConfig(),
@ -262,22 +261,4 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
}
}
void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference) {
std::cerr
<< "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
"==========================================================\n\n"
"To run with pre-2025 weights, specify --tokenizer and --weights.\n"
"With the single-file weights format, specify just --weights.\n";
std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
"--weights gemma2-2b-it-sfp.sbs\n";
std::cerr << "\n*Model Loading Arguments*\n\n";
loader.Help();
std::cerr << "\n*Threading Arguments*\n\n";
threading.Help();
std::cerr << "\n*Inference Arguments*\n\n";
inference.Help();
std::cerr << "\n";
}
} // namespace gcpp

View File

@ -23,7 +23,7 @@
#include "gemma/configs.h" #include "gemma/configs.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "gemma/gemma_args.h" #include "gemma/gemma_args.h" // IWYU pragma: export
#include "gemma/tokenizer.h" // WrapAndTokenize #include "gemma/tokenizer.h" // WrapAndTokenize
#include "ops/matmul.h" #include "ops/matmul.h"
#include "util/threading_context.h" #include "util/threading_context.h"
@ -50,10 +50,8 @@ struct QueryResultAndMetrics {
// Convenience class to load a model and run inference.
class GemmaEnv {
public:
// Calls the other constructor with *Args arguments initialized from argv.
GemmaEnv(int argc, char** argv);
GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference);
explicit GemmaEnv(const GemmaArgs& args);
MatMulEnv& Env() { return env_; }
size_t MaxGeneratedTokens() const {
@ -125,6 +123,8 @@ class GemmaEnv {
MatMulEnv& MutableEnv() { return env_; }
private:
// This is used to ensure that InternalInit is called before anything else.
int initializer_value_ = 0;
ThreadingContext ctx_;
MatMulEnv env_;
Gemma gemma_;
@ -135,12 +135,9 @@ class GemmaEnv {
// Logs the inference speed in tokens/sec.
void LogSpeedStats(double time_start, size_t total_tokens);
void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference, const ModelConfig& config,
void ShowConfig(const GemmaArgs& args, const ModelConfig& config,
WeightsPtrs::Mode weight_read_mode,
const ThreadingContext& ctx);
void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference);
} // namespace gcpp
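
For reference, the startup sequence that these tools converge on after this change, assembled from the call sites in this diff (this consolidated main() is illustrative, not a file in the diff; header paths mirror the includes used above):

```
#include "evals/benchmark_helper.h"  // GemmaEnv
#include "gemma/gemma_args.h"
#include "util/args.h"

int main(int argc, char** argv) {
  // Records which argv entries each *Args parser consumes.
  gcpp::ConsumedArgs consumed(argc, argv);
  // GemmaArgs bundles the former LoaderArgs/ThreadingArgs/InferenceArgs
  // (still reachable as args.loader, args.threading, args.inference).
  gcpp::GemmaArgs args(argc, argv, consumed);
  if (gcpp::HasHelp(argc, argv)) {
    args.Help();
    return 0;
  }
  // Flags that no parser claimed are rejected here.
  consumed.AbortIfUnconsumed();
  gcpp::GemmaEnv env(args);  // replaces the old GemmaEnv(argc, argv)
  return 0;
}
```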

View File

@ -98,7 +98,11 @@ BENCHMARK(BM_coding_prompt)
->UseRealTime();
int main(int argc, char** argv) {
gcpp::GemmaEnv env(argc, argv);
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
consumed.AbortIfUnconsumed();
gcpp::GemmaEnv env(args);
env.SetMaxGeneratedTokens(256);
gcpp::s_env = &env;

View File

@ -31,7 +31,9 @@ namespace gcpp {
class PromptArgs : public ArgsBase<PromptArgs> {
public:
PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
PromptArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
Path layers_output; // optional
std::string prompt;
@ -51,11 +53,15 @@ class PromptArgs : public ArgsBase<PromptArgs> {
};
int Run(int argc, char** argv) {
PromptArgs prompt_args(argc, argv);
ConsumedArgs consumed(argc, argv);
const GemmaArgs args(argc, argv, consumed);
const PromptArgs prompt_args(argc, argv, consumed);
AbortIfInvalidArgs(prompt_args);
consumed.AbortIfUnconsumed();
json json_output;
GemmaEnv env(argc, argv);
GemmaEnv env(args);
env.MutableConfig().layers_output =
prompt_args.layers_output.Empty()
? LayersOutputFunc()

View File

@ -48,7 +48,7 @@ class GemmaBatchBench : public ::testing::Test {
}
};
TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
std::vector<std::string> GenerateInputs() {
std::vector<std::string> prompts = {
{"Describe dynamic programming."},
{"Explain how electric cars work."},
@ -122,33 +122,38 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
inputs.push_back(prompts[qpos++]);
if (qpos == prompts.size()) qpos = 0;
}
s_env->SetMaxGeneratedTokens(24);
std::vector<std::string> responses = BatchGemmaReply(inputs);
for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size());
++i) {
fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
}
PROFILER_PRINT_RESULTS();
// Run again: prefill will be faster due to autotuning. Fewer decode steps
// because those are already fast.
s_env->SetMaxGeneratedTokens(2);
responses = BatchGemmaReply(inputs);
PROFILER_PRINT_RESULTS();
return inputs;
}
TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
s_env->SetMaxGeneratedTokens(12);
const std::vector<std::string> inputs = GenerateInputs();
// Run multiple times so that auto-tuning is closer to complete.
for (size_t rep = 0; rep < 4; ++rep) {
std::vector<std::string> responses = BatchGemmaReply(inputs);
for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size());
++i) {
fprintf(stderr, "Rep %zu batch answer %zu '%s'\n\n", rep, i,
responses[i].c_str());
}
PROFILER_PRINT_RESULTS();
}
}
} // namespace
} // namespace gcpp
int main(int argc, char** argv) {
fprintf(stderr, "GemmaEnv setup..\n");
gcpp::GemmaEnv env(argc, argv);
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
consumed.AbortIfUnconsumed();
gcpp::GemmaEnv env(args);
gcpp::s_env = &env;
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -22,7 +22,6 @@
#include "evals/benchmark_helper.h" #include "evals/benchmark_helper.h"
#include "gemma/configs.h" #include "gemma/configs.h"
#include "io/io.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h" #include "hwy/tests/hwy_gtest.h"
@ -42,7 +41,11 @@ class GemmaTest : public ::testing::Test {
// Requires argc/argv, hence do not use `SetUpTestSuite`.
static void InitEnv(int argc, char** argv) {
HWY_ASSERT(s_env == nullptr); // Should only be called once.
s_env = new GemmaEnv(argc, argv);
ConsumedArgs consumed(argc, argv);
GemmaArgs args(argc, argv, consumed);
consumed.AbortIfUnconsumed();
s_env = new GemmaEnv(args);
const gcpp::ModelConfig& config = s_env->GetGemma()->Config();
fprintf(stderr, "Using %s\n", config.Specifier().c_str());
}
@ -130,7 +133,7 @@ TEST_F(GemmaTest, Multiturn) {
// Note: we do not rewind any <end_of_turn> tokens here. If the model
// produced one and WrapAndTokenize() inserts another one, it will just be
// duplicated.
mutable_prompt = "Please repeat all prior statements.";
mutable_prompt = "Please repeat what I just told you.";
tokens = WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(),
config.wrapping, abs_pos, mutable_prompt);
@ -167,6 +170,9 @@ TEST_F(GemmaTest, CrossEntropySmall) {
case gcpp::Model::GEMMA2_27B:
EXPECT_NEAR(entropy, 1.30f, 0.02f);
break;
case gcpp::Model::GEMMA3_270M:
EXPECT_NEAR(entropy, 1.41f, 0.02f);
break;
default:
FAIL() << "no entropy expectation for this model";
break;
@ -178,7 +184,6 @@ TEST_F(GemmaTest, CrossEntropySmall) {
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
gcpp::InternalInit();
gcpp::GemmaTest::InitEnv(argc, argv);
int ret = RUN_ALL_TESTS();
gcpp::GemmaTest::DeleteEnv();

View File

@ -31,7 +31,9 @@
namespace gcpp {
struct JsonArgs : public ArgsBase<JsonArgs> {
JsonArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
JsonArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
Path input;
@ -151,10 +153,14 @@ void Run(GemmaEnv& env, JsonArgs& json) {
int main(int argc, char** argv) {
{
PROFILER_ZONE("Startup.all");
gcpp::GemmaEnv env(argc, argv);
gcpp::JsonArgs json(argc, argv);
gcpp::AbortIfInvalidArgs(json);
gcpp::Run(env, json);
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
gcpp::JsonArgs json_args(argc, argv, consumed);
gcpp::AbortIfInvalidArgs(json_args);
consumed.AbortIfUnconsumed();
gcpp::GemmaEnv env(args);
gcpp::Run(env, json_args);
}
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
return 0;

View File

@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
include(FetchContent)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 3b680cde3a556bead9cc23c8f595d07a44d5a0d5)
FetchContent_MakeAvailable(highway)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9)
FetchContent_MakeAvailable(sentencepiece)

View File

@ -24,20 +24,20 @@
#include <vector>
#include "gemma/gemma.h"
#include "gemma/gemma_args.h" // LoaderArgs
#include "gemma/gemma_args.h" // GemmaArgs
#include "gemma/tokenizer.h"
#include "util/args.h"
#include "util/threading_context.h"
#include "hwy/base.h"
int main(int argc, char** argv) {
gcpp::LoaderArgs loader(argc, argv);
gcpp::ThreadingArgs threading(argc, argv);
gcpp::InferenceArgs inference(argc, argv);
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
if (gcpp::HasHelp(argc, argv)) {
loader.Help();
args.Help();
return 0;
}
consumed.AbortIfUnconsumed();
// Demonstrate constrained decoding by never outputting certain tokens.
std::set<int> reject_tokens;
@ -49,10 +49,10 @@ int main(int argc, char** argv) {
}
// Instantiate model and KV Cache
gcpp::ThreadingContext ctx(threading);
gcpp::ThreadingContext ctx(args.threading);
gcpp::MatMulEnv env(ctx);
gcpp::Gemma gemma(loader, inference, ctx);
gcpp::KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
gcpp::Gemma gemma(args, ctx);
gcpp::KVCache kv_cache(gemma.Config(), args.inference, ctx.allocator);
size_t generated = 0;
// Tokenize instructions.

View File

@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
include(FetchContent)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 3b680cde3a556bead9cc23c8f595d07a44d5a0d5)
FetchContent_MakeAvailable(highway)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
FetchContent_MakeAvailable(sentencepiece)

View File

@ -23,7 +23,7 @@
#include <vector>
#include "third_party/gemma_cpp/gemma/gemma.h"
#include "third_party/gemma_cpp/gemma/gemma_args.h" // LoaderArgs
#include "third_party/gemma_cpp/gemma/gemma_args.h" // GemmaArgs
#include "third_party/gemma_cpp/gemma/tokenizer.h"
#include "third_party/gemma_cpp/ops/matmul.h"
#include "third_party/gemma_cpp/util/threading_context.h"
@ -31,18 +31,11 @@
class SimplifiedGemma {
public:
SimplifiedGemma(const gcpp::LoaderArgs& loader,
const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
: ctx_(threading),
SimplifiedGemma(const gcpp::GemmaArgs& args)
: ctx_(args.threading),
env_(ctx_),
gemma_(loader, inference, ctx_),
kv_cache_(gemma_.Config(), inference, ctx_.allocator) {}
SimplifiedGemma(int argc, char** argv)
: SimplifiedGemma(gcpp::LoaderArgs(argc, argv),
gcpp::ThreadingArgs(argc, argv),
gcpp::InferenceArgs(argc, argv)) {}
gemma_(args, ctx_),
kv_cache_(gemma_.Config(), args.inference, ctx_.allocator) {}
void Generate(std::string& prompt, size_t max_generated_tokens = 1024,
float temperature = 0.7,

View File

@ -18,28 +18,18 @@
#include <string>
#include "third_party/gemma_cpp/examples/simplified_gemma/gemma.hpp"
#include "gemma/gemma_args.h" // LoaderArgs
#include "gemma/gemma_args.h"
int main(int argc, char** argv) {
// Standard usage: LoaderArgs takes argc and argv as input, then parses
// necessary flags.
gcpp::LoaderArgs loader(argc, argv);
// Optional: LoaderArgs can also take tokenizer and weights paths directly.
//
// gcpp::LoaderArgs loader("/path/to/tokenizer", "/path/to/weights",
// "model_identifier");
// Optional: ThreadingArgs and InferenceArgs can be passed in as well. If not
// specified, default values will be used.
//
// gcpp::InferenceArgs inference(argc, argv);
// gcpp::ThreadingArgs threading(argc, argv);
// SimplifiedGemma gemma(loader, threading, inference);
SimplifiedGemma gemma(loader);
// Sets arguments from argc and argv. Note that you can instead pass in
// LoaderArgs, ThreadingArgs, and InferenceArgs directly.
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
consumed.AbortIfUnconsumed();
SimplifiedGemma gemma(args);
std::string prompt = "Write a greeting to the world.";
gemma.Generate(prompt, 256, 0.6);
return 0;
}

View File

@ -23,44 +23,54 @@
#include <atomic>
#include <vector>
#include "gemma/configs.h" // ModelConfig
#include "ops/ops.h" // CreateInvTimescale
#include "util/basics.h" // BF16
#include "util/mat.h" // MatStorageT
#include "gemma/gemma_args.h" // AttentionImpl
#include "gemma/kv_cache.h"
#include "gemma/tensor_stats.h"
#include "ops/ops.h" // CreateInvTimescale
#include "util/basics.h" // BF16
#include "util/mat.h" // MatStorageT
#include "util/threading_context.h"
namespace gcpp {
struct AttentionActivations { // Returns the scale value to use for the query in the attention computation.
// Returns the scale value to use for the query in the attention computation. // Also called by ops_test.
// Also called by ops_test. static inline float ChooseQueryScale(const ModelConfig& config) {
static inline float ChooseQueryScale(const ModelConfig& config) { const LayerConfig& layer_config = config.layer_configs[0];
const LayerConfig& layer_config = config.layer_configs[0]; if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads) return 1.0f /
return 1.0f / sqrtf(static_cast<float>(config.model_dim / layer_config.heads));
sqrtf(static_cast<float>(config.model_dim / layer_config.heads)); // QueryScaleType::SqrtKeySize
// QueryScaleType::SqrtKeySize return 1.0f / sqrtf(static_cast<float>(layer_config.qkv_dim));
return 1.0f / sqrtf(static_cast<float>(layer_config.qkv_dim)); }
}
struct AttentionActivations {
AttentionActivations( AttentionActivations(
const ModelConfig& config, const LayerConfig& layer_config, const ModelConfig& config, const LayerConfig& layer_config,
size_t batch_size, size_t seq_len, const Allocator& allocator, size_t batch_size, size_t seq_len, const RuntimeConfig& runtime_config,
const Allocator& allocator,
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs) std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
: config(config), : // `vocab_size == 0` means it is for Vit part, VitAttention is still
// MHA and does not use an external KV cache.
// `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
// and does not use an external KV cache.
q(MatFactory("q", batch_size, q(MatFactory("q", batch_size,
config.vocab_size == 0 config.vocab_size == 0
? layer_config.heads * 3 * layer_config.qkv_dim ? layer_config.heads * 3 * layer_config.qkv_dim
: layer_config.heads * layer_config.qkv_dim, : layer_config.heads * layer_config.qkv_dim,
allocator)), allocator)),
q_bf(MatFactory("q_bf", batch_size,
config.vocab_size == 0
? layer_config.heads * 3 * layer_config.qkv_dim
: layer_config.heads * layer_config.qkv_dim,
allocator)),
q_T(MatFactory("q_T", layer_config.qkv_dim, q_T(MatFactory("q_T", layer_config.qkv_dim,
config.vocab_size == 0 config.vocab_size == 0
? batch_size * layer_config.heads * 3 ? batch_size * layer_config.heads * 3
: batch_size * layer_config.heads, : batch_size * layer_config.heads,
allocator)), allocator)),
vit_Q(MatFactory("Q2", batch_size, layer_config.qkv_dim, allocator)),
vit_K(MatFactory("K2", seq_len, layer_config.qkv_dim, allocator)),
vit_C(MatFactory("C2", batch_size, seq_len, allocator)),
pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size, pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
config.model_dim, allocator)), config.model_dim, allocator)),
att(MatFactory("att", batch_size, layer_config.heads * seq_len, att(MatFactory("att", batch_size, layer_config.heads * seq_len,
@ -68,6 +78,10 @@ struct AttentionActivations {
att_out(MatFactory("att_out", batch_size, att_out(MatFactory("att_out", batch_size,
layer_config.heads * layer_config.qkv_dim, layer_config.heads * layer_config.qkv_dim,
allocator)), allocator)),
softmax_max(MatFactory("softmax_max", batch_size, layer_config.heads,
allocator)),
softmax_d(
MatFactory("softmax_d", batch_size, layer_config.heads, allocator)),
att_sums( att_sums(
MatFactory("att_sums", batch_size, config.model_dim, allocator)), MatFactory("att_sums", batch_size, config.model_dim, allocator)),
@ -76,11 +90,7 @@ struct AttentionActivations {
layer_config.post_qk == PostQKType::HalfRope)), layer_config.post_qk == PostQKType::HalfRope)),
inv_timescale_global(CreateInvTimescale( inv_timescale_global(CreateInvTimescale(
allocator, layer_config.qkv_dim, allocator, layer_config.qkv_dim,
layer_config.post_qk == PostQKType::HalfRope, 1000000.0)), layer_config.post_qk == PostQKType::HalfRope, 1000000.0)) {
div_seq_len(static_cast<uint32_t>(seq_len)),
div_heads(static_cast<uint32_t>(layer_config.heads)),
query_scale(ChooseQueryScale(config)) {
// Batch size can be 0 in experimental code so do not assert. // Batch size can be 0 in experimental code so do not assert.
if (batch_size == 0) { if (batch_size == 0) {
static std::atomic_flag warned = ATOMIC_FLAG_INIT; static std::atomic_flag warned = ATOMIC_FLAG_INIT;
@ -94,44 +104,153 @@ struct AttentionActivations {
// If we forget any MatMul outputs here, debug builds print a warning but // If we forget any MatMul outputs here, debug builds print a warning but
// fill them in each MatMul call. // fill them in each MatMul call.
q.AllocateAndAttachRowPtrs(row_ptrs); q.AllocateAndAttachRowPtrs(row_ptrs);
q_bf.AllocateAndAttachRowPtrs(row_ptrs);
q_T.AllocateAndAttachRowPtrs(row_ptrs); q_T.AllocateAndAttachRowPtrs(row_ptrs);
vit_C.AllocateAndAttachRowPtrs(row_ptrs);
att_sums.AllocateAndAttachRowPtrs(row_ptrs); att_sums.AllocateAndAttachRowPtrs(row_ptrs);
} }
void SetBatchSize(size_t batch_size) { void SetBatchSize(size_t batch_size) {
q.OverrideRows(batch_size); q.OverrideRows(batch_size);
q_bf.OverrideRows(batch_size);
// q_T rows are always qkv_dim! // q_T rows are always qkv_dim!
vit_Q.OverrideRows(batch_size);
// vit_K stays seq_len!
vit_C.OverrideRows(batch_size);
pre_att_rms_out.OverrideRows(batch_size); pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size); att.OverrideRows(batch_size);
att_out.OverrideRows(batch_size); att_out.OverrideRows(batch_size);
softmax_max.OverrideRows(batch_size);
softmax_d.OverrideRows(batch_size);
att_sums.OverrideRows(batch_size); att_sums.OverrideRows(batch_size);
// `inv_timescale*` are not batched.
} }
const ModelConfig& config;
MatStorageT<float> q; // query MatStorageT<float> q; // query
MatStorageT<float> q_T; // Transposed to maximize attention speed. MatStorageT<BF16> q_bf;
MatStorageT<BF16> q_T; // Transposed to maximize attention speed.
MatStorageT<float> vit_Q;
MatStorageT<float> vit_K;
MatStorageT<float> vit_C;
MatStorageT<float> pre_att_rms_out; MatStorageT<float> pre_att_rms_out;
MatStorageT<float> att; // attention vector MatStorageT<float> att; // attention vector
MatStorageT<float> att_out; // attention output MatStorageT<float> att_out; // attention output
MatStorageT<float> softmax_max; // see OnlineSoftmaxState
MatStorageT<float> softmax_d; // see OnlineSoftmaxState
// Accumulation of attention outputs over heads // Accumulation of attention outputs over heads
MatStorageT<BF16> att_sums; MatStorageT<BF16> att_sums;
// Rope // Rope
MatStorageT<float> inv_timescale; MatStorageT<float> inv_timescale;
MatStorageT<float> inv_timescale_global; MatStorageT<float> inv_timescale_global;
};
// A non-owning view of AttentionActivations.
struct AttentionActivationsPtrs {
AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len)
: config(config),
div_seq_len(static_cast<uint32_t>(seq_len)),
div_heads(static_cast<uint32_t>(config.layer_configs[0].heads)),
query_scale(ChooseQueryScale(config)) {}
AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len,
const AttentionActivations& activations)
: AttentionActivationsPtrs(config, seq_len) {
q = activations.q;
q_bf = activations.q_bf;
q_T = activations.q_T;
vit_Q = activations.vit_Q;
vit_K = activations.vit_K;
vit_C = activations.vit_C;
pre_att_rms_out = activations.pre_att_rms_out;
att = activations.att;
att_out = activations.att_out;
softmax_max = activations.softmax_max;
softmax_d = activations.softmax_d;
att_sums = activations.att_sums;
inv_timescale = activations.inv_timescale;
inv_timescale_global = activations.inv_timescale_global;
}
void SetBatchSize(size_t batch_size) {
q.OverrideRows(batch_size);
q_bf.OverrideRows(batch_size);
// q_T rows are always qkv_dim!
vit_Q.OverrideRows(batch_size);
// vit_K stays seq_len!
vit_C.OverrideRows(batch_size);
pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
att_out.OverrideRows(batch_size);
softmax_max.OverrideRows(batch_size);
softmax_d.OverrideRows(batch_size);
att_sums.OverrideRows(batch_size);
// `inv_timescale*` are not batched.
}
size_t SeqLen() const {
return static_cast<size_t>(div_seq_len.GetDivisor());
}
const ModelConfig& config;
// For the matrices below, the batch_size dimension is really qbatch.Size() *
// token_batch_size, but in all known uses, one of those is 1. Specifically,
// during PrefillTBatch, it is the prompt length (up to some max batch size),
// and otherwise it is qbatch.Size().
// Query matrix of size batch_size x (q_heads * qkv_dim).
MatPtrT<float> q;
// Query matrix of size batch_size x (q_heads * qkv_dim).
MatPtrT<BF16> q_bf;
// Transposed query matrix for faster Q*K^T.
MatPtrT<BF16> q_T;
MatPtrT<float> vit_Q;
MatPtrT<float> vit_K;
MatPtrT<float> vit_C;
// Output of RMSNorm before attention, size batch_size x model_dim.
MatPtrT<float> pre_att_rms_out;
// Attention scores computed from Q*K^T, size batch_size x (q_heads *
// seq_len).
MatPtrT<float> att;
// Attention output computed from att * V, size batch_size x (q_heads *
// qkv_dim).
MatPtrT<float> att_out;
// The maximum logit value encountered when computing att_out from att,
// size batch_size x q_heads. See OnlineSoftmaxState for details.
// WARNING: Only filled in for AttentionImpl::kOld.
MatPtrT<float> softmax_max;
// The sum of scaled exponentials when computing att_out from att,
// size batch_size x q_heads. See OnlineSoftmaxState for details.
// WARNING: Only filled in for AttentionImpl::kOld.
MatPtrT<float> softmax_d;
// Accumulation of attention outputs over heads, size batch_size x
// model_dim.
MatPtrT<BF16> att_sums;
// Inverse timescales for RoPE computation.
MatPtrT<float> inv_timescale;
// Inverse timescales for global RoPE computation.
MatPtrT<float> inv_timescale_global;
// Divisor for faster division by sequence length.
hwy::Divisor div_seq_len; hwy::Divisor div_seq_len;
// Unfortunately, some models have had non-power-of-two heads. // Divisor for faster division by number of heads.
hwy::Divisor div_heads; hwy::Divisor div_heads;
// Query scaling factor for attention computation.
float query_scale; float query_scale;
}; };
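The softmax_max / softmax_d pair above corresponds to the usual online-softmax bookkeeping. A minimal sketch of that recurrence, for orientation only; the struct and member names below are illustrative and not the repository's OnlineSoftmaxState:

#include <cmath>

// Running (max, denominator) pair for a streaming softmax: m is the maximum
// logit seen so far, d is the sum of exp(logit - m) over those logits.
struct OnlineSoftmaxSketch {
  float m = -INFINITY;
  float d = 0.0f;

  void Update(float logit) {
    const float new_m = logit > m ? logit : m;
    // Rescale the existing sum to the new maximum, then add the new term.
    d = d * std::exp(m - new_m) + std::exp(logit - new_m);
    m = new_m;
  }

  // Normalized probability of a previously seen logit x.
  float Prob(float x) const { return std::exp(x - m) / d; }
};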
struct Activations { struct Activations {
Activations(const ModelConfig& config, size_t batch_size, size_t seq_len, Activations(const RuntimeConfig& runtime_config, const ModelConfig& config,
ThreadingContext& ctx, size_t batch_size, size_t seq_len, ThreadingContext& ctx,
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs) std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
: layer_config(config.layer_configs[0]), : layer_config(config.layer_configs[0]),
@ -150,8 +269,18 @@ struct Activations {
ffw_out( ffw_out(
MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)), MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)),
attention(config, layer_config, batch_size, seq_len, ctx.allocator, max_workers(ctx.pools.MaxWorkers()),
row_ptrs) { s_ffw_in(config.num_layers, max_workers),
s_ffw_hidden(config.num_layers, max_workers),
s_ffw_out(config.num_layers, max_workers),
s_w_gating_einsum_w1(config.num_layers, max_workers),
s_w_gating_einsum_w2(config.num_layers, max_workers),
s_w_linear_w(config.num_layers, max_workers),
attention_impl(runtime_config.attention_impl),
attention_storage(config, layer_config, batch_size, seq_len,
runtime_config, ctx.allocator, row_ptrs),
attention(config, seq_len, attention_storage) {
HWY_ASSERT(batch_size != 0); HWY_ASSERT(batch_size != 0);
// For MatMul outputs, precompute their row pointers. // For MatMul outputs, precompute their row pointers.
@ -167,6 +296,12 @@ struct Activations {
// Note that BindC on any MatMul output considerably slows down Prefill. // Note that BindC on any MatMul output considerably slows down Prefill.
} }
~Activations() {
s_ffw_in.ReduceAndPrint("ffw_in");
s_ffw_hidden.ReduceAndPrint("ffw_hidden");
s_ffw_out.ReduceAndPrint("ffw_out");
}
// Negligible CPU time. // Negligible CPU time.
void SetBatchSize(size_t batch_size) { void SetBatchSize(size_t batch_size) {
x.OverrideRows(batch_size); x.OverrideRows(batch_size);
@ -179,12 +314,15 @@ struct Activations {
C2.OverrideRows(batch_size); C2.OverrideRows(batch_size);
ffw_out.OverrideRows(batch_size); ffw_out.OverrideRows(batch_size);
attention_storage.SetBatchSize(batch_size);
// `AttentionActivationsPtrs` holds `MatPtrT` views, which also require updating;
// their row overrides are not updated when the underlying storage changes.
attention.SetBatchSize(batch_size); attention.SetBatchSize(batch_size);
} }
const LayerConfig& layer_config; const LayerConfig& layer_config;
MatStorageT<float> x; // input MatStorageT<float> x; // input
MatStorageT<BF16> x_bf; // output of final RMSNorm, input to EmbeddingMatmul MatStorageT<BF16> x_bf; // output of final RMSNorm, input to EmbeddingMatmul
MatStorageT<float> logits; // TODO: BF16 after Softmax supports that. MatStorageT<float> logits; // TODO: BF16 after Softmax supports that.
MatStorageT<uint32_t> sampled; // batch_size x 3 (padded) MatStorageT<uint32_t> sampled; // batch_size x 3 (padded)
@ -195,7 +333,19 @@ struct Activations {
MatStorageT<BF16> C2; MatStorageT<BF16> C2;
MatStorageT<float> ffw_out; MatStorageT<float> ffw_out;
AttentionActivations attention; const size_t max_workers;
TensorStats s_ffw_in;
TensorStats s_ffw_hidden; // after Activation+gating
TensorStats s_ffw_out;
TensorStats s_w_gating_einsum_w1;
TensorStats s_w_gating_einsum_w2;
TensorStats s_w_linear_w;
AttentionImpl attention_impl;
AttentionActivations attention_storage;
AttentionActivationsPtrs attention;
}; };
} // namespace gcpp } // namespace gcpp
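The Activations initializer list above pairs an owning AttentionActivations with a non-owning AttentionActivationsPtrs view, and SetBatchSize forwards to both. A sketch of that pairing in isolation; the helper function is hypothetical, and the constructor arguments mirror the initializer list shown in this diff.

#include <cstdint>
#include <vector>

#include "gemma/activations.h"  // AttentionActivations, AttentionActivationsPtrs
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h"

namespace gcpp {

// Hypothetical helper, not part of the library: shows the storage/view pairing.
void ResizeAttentionSketch(
    const ModelConfig& config, const RuntimeConfig& runtime_config,
    ThreadingContext& ctx, size_t batch_size, size_t seq_len,
    size_t new_batch_size,
    std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs) {
  const LayerConfig& layer_config = config.layer_configs[0];
  // Owning storage allocates the MatStorageT buffers.
  AttentionActivations storage(config, layer_config, batch_size, seq_len,
                               runtime_config, ctx.allocator, row_ptrs);
  // Non-owning view copies MatPtrT handles plus divisors and query scale.
  AttentionActivationsPtrs view(config, seq_len, storage);

  // Both must observe a batch-size change: the view's MatPtrT row overrides
  // are not refreshed automatically when the storage changes.
  storage.SetBatchSize(new_batch_size);
  view.SetBatchSize(new_batch_size);
}

}  // namespace gcpp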

View File

@ -15,18 +15,22 @@
// Test client for API server // Test client for API server
#include <iostream> #include <stdio.h>
#include <string>
#include <sstream>
#include <cstdlib> #include <cstdlib>
#include <cstring> #include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include "httplib.h" #include "httplib.h"
#include "nlohmann/json.hpp"
#include "gemma/gemma_args.h" #include "gemma/gemma_args.h"
#include "nlohmann/json.hpp"
using json = nlohmann::json; using json = nlohmann::json;
namespace gcpp {
// ANSI color codes // ANSI color codes
const std::string RESET = "\033[0m"; const std::string RESET = "\033[0m";
const std::string BOLD = "\033[1m"; const std::string BOLD = "\033[1m";
@ -37,9 +41,15 @@ const std::string YELLOW = "\033[33m";
const std::string RED = "\033[31m"; const std::string RED = "\033[31m";
class APIClient { class APIClient {
public: public:
APIClient(const std::string& host, int port, const std::string& api_key = "", const std::string& model = "gemma3-4b") APIClient(const std::string& host, int port, const std::string& api_key = "",
: host_(host), port_(port), api_key_(api_key), model_(model), use_https_(port == 443), interactive_mode_(false) { const std::string& model = "gemma3-4b")
: host_(host),
port_(port),
api_key_(api_key),
model_(model),
use_https_(port == 443),
interactive_mode_(false) {
if (use_https_) { if (use_https_) {
ssl_client_ = std::make_unique<httplib::SSLClient>(host, port); ssl_client_ = std::make_unique<httplib::SSLClient>(host, port);
ssl_client_->set_read_timeout(60, 0); ssl_client_->set_read_timeout(60, 0);
@ -55,22 +65,25 @@ public:
// Unified request processing for both public and local APIs // Unified request processing for both public and local APIs
json ProcessRequest(const json& request, bool stream = true) { json ProcessRequest(const json& request, bool stream = true) {
bool is_public_api = !api_key_.empty(); bool is_public_api = !api_key_.empty();
std::string endpoint; std::string endpoint;
if (is_public_api) { if (is_public_api) {
endpoint = stream ? "/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse" endpoint =
: "/v1beta/models/gemini-2.0-flash:generateContent"; stream
? "/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse"
: "/v1beta/models/gemini-2.0-flash:generateContent";
} else { } else {
endpoint = stream ? "/v1beta/models/" + model_ + ":streamGenerateContent" endpoint = stream ? "/v1beta/models/" + model_ + ":streamGenerateContent"
: "/v1beta/models/" + model_ + ":generateContent"; : "/v1beta/models/" + model_ + ":generateContent";
} }
// Only show verbose output in non-interactive mode // Only show verbose output in non-interactive mode
if (!interactive_mode_) { if (!interactive_mode_) {
std::cout << "\n" << BOLD << BLUE << "📤 POST " << endpoint << RESET << std::endl; std::cout << "\n"
<< BOLD << BLUE << "📤 POST " << endpoint << RESET << std::endl;
std::cout << "Request: " << request.dump(2) << std::endl; std::cout << "Request: " << request.dump(2) << std::endl;
} }
if (stream) { if (stream) {
return ProcessStreamingRequest(request, endpoint); return ProcessStreamingRequest(request, endpoint);
} else { } else {
@ -81,21 +94,24 @@ public:
void TestGenerateContent(const std::string& prompt, bool stream = true) { void TestGenerateContent(const std::string& prompt, bool stream = true) {
json request = CreateAPIRequest(prompt); json request = CreateAPIRequest(prompt);
json response = ProcessRequest(request, stream); json response = ProcessRequest(request, stream);
if (response.contains("error")) { if (response.contains("error")) {
std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl; std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET
<< std::endl;
} }
} }
void TestListModels() { void TestListModels() {
std::cout << "\n" << BOLD << BLUE << "📤 GET /v1beta/models" << RESET << std::endl; std::cout << "\n"
<< BOLD << BLUE << "📤 GET /v1beta/models" << RESET << std::endl;
httplib::Headers headers; httplib::Headers headers;
if (!api_key_.empty()) { if (!api_key_.empty()) {
headers.emplace("X-goog-api-key", api_key_); headers.emplace("X-goog-api-key", api_key_);
} }
auto res = use_https_ ? ssl_client_->Get("/v1beta/models", headers) : client_->Get("/v1beta/models", headers); auto res = use_https_ ? ssl_client_->Get("/v1beta/models", headers)
: client_->Get("/v1beta/models", headers);
if (res && res->status == 200) { if (res && res->status == 200) {
json response = json::parse(res->body); json response = json::parse(res->body);
std::cout << GREEN << "✅ Available models:" << RESET << std::endl; std::cout << GREEN << "✅ Available models:" << RESET << std::endl;
@ -106,49 +122,53 @@ public:
} }
void InteractiveChat() { void InteractiveChat() {
std::cout << "\n" << BOLD << CYAN << "💬 Interactive Chat Mode (with session)" << RESET << std::endl; std::cout << "\n"
<< BOLD << CYAN << "💬 Interactive Chat Mode (with session)"
<< RESET << std::endl;
std::cout << "Type ':gemma %q' to end.\n" << std::endl; std::cout << "Type ':gemma %q' to end.\n" << std::endl;
interactive_mode_ = true; interactive_mode_ = true;
json messages; json messages;
while (true) { while (true) {
std::cout << BOLD << BLUE << "You: " << RESET; std::cout << BOLD << BLUE << "You: " << RESET;
std::string input; std::string input;
std::getline(std::cin, input); std::getline(std::cin, input);
if (input == ":gemma %q") { if (input == ":gemma %q") {
std::cout << BOLD << YELLOW << "👋 Goodbye!" << RESET << std::endl; std::cout << BOLD << YELLOW << "👋 Goodbye!" << RESET << std::endl;
break; break;
} }
if (input.empty()) continue; if (input.empty()) continue;
// Add user message with proper role // Add user message with proper role
json user_message = {{"parts", {{{"text", input}}}}}; json user_message = {{"parts", {{{"text", input}}}}};
if (!api_key_.empty()) { if (!api_key_.empty()) {
user_message["role"] = "user"; user_message["role"] = "user";
} }
messages.push_back(user_message); messages.push_back(user_message);
// Create request using unified logic // Create request using unified logic
json request = CreateAPIRequest("", messages); json request = CreateAPIRequest("", messages);
std::cout << BOLD << GREEN << "Assistant: " << RESET; std::cout << BOLD << GREEN << "Assistant: " << RESET;
// Use unified processing - streaming for real-time output // Use unified processing - streaming for real-time output
json response = ProcessRequest(request, true); json response = ProcessRequest(request, true);
if (response.contains("candidates") && !response["candidates"].empty()) { if (response.contains("candidates") && !response["candidates"].empty()) {
auto& candidate = response["candidates"][0]; auto& candidate = response["candidates"][0];
if (candidate.contains("content") && candidate["content"].contains("parts")) { if (candidate.contains("content") &&
candidate["content"].contains("parts")) {
for (const auto& part : candidate["content"]["parts"]) { for (const auto& part : candidate["content"]["parts"]) {
if (part.contains("text")) { if (part.contains("text")) {
std::string assistant_response = part["text"].get<std::string>(); std::string assistant_response = part["text"].get<std::string>();
// For streaming, the response is already displayed in real-time // For streaming, the response is already displayed in real-time
// Just add to message history for context // Just add to message history for context
json assistant_message = {{"parts", {{{"text", assistant_response}}}}}; json assistant_message = {
{"parts", {{{"text", assistant_response}}}}};
if (!api_key_.empty()) { if (!api_key_.empty()) {
assistant_message["role"] = "model"; assistant_message["role"] = "model";
} }
@ -157,23 +177,21 @@ public:
} }
} }
} else if (response.contains("error")) { } else if (response.contains("error")) {
std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl; std::cerr << RED << "❌ Error: " << response["error"]["message"]
<< RESET << std::endl;
} }
std::cout << std::endl; std::cout << std::endl;
} }
} }
private: private:
json CreateAPIRequest(const std::string& prompt, const json& messages = json::array()) { json CreateAPIRequest(const std::string& prompt,
const json& messages = json::array()) {
json request = { json request = {
{"generationConfig", { {"generationConfig",
{"temperature", 0.9}, {{"temperature", 0.9}, {"topK", 1}, {"maxOutputTokens", 1024}}}};
{"topK", 1},
{"maxOutputTokens", 1024}
}}
};
if (messages.empty()) { if (messages.empty()) {
// Single prompt // Single prompt
json user_message = {{"parts", {{{"text", prompt}}}}}; json user_message = {{"parts", {{{"text", prompt}}}}};
@ -185,44 +203,48 @@ private:
// Use provided message history // Use provided message history
request["contents"] = messages; request["contents"] = messages;
} }
return request; return request;
} }
json ProcessNonStreamingRequest(const json& request, const std::string& endpoint) { json ProcessNonStreamingRequest(const json& request,
const std::string& endpoint) {
httplib::Headers headers = {{"Content-Type", "application/json"}}; httplib::Headers headers = {{"Content-Type", "application/json"}};
if (!api_key_.empty()) { if (!api_key_.empty()) {
headers.emplace("X-goog-api-key", api_key_); headers.emplace("X-goog-api-key", api_key_);
} }
auto res = use_https_ ? ssl_client_->Post(endpoint, headers, request.dump(), "application/json") auto res = use_https_ ? ssl_client_->Post(endpoint, headers, request.dump(),
: client_->Post(endpoint, headers, request.dump(), "application/json"); "application/json")
: client_->Post(endpoint, headers, request.dump(),
"application/json");
if (res && res->status == 200) { if (res && res->status == 200) {
json response = json::parse(res->body); json response = json::parse(res->body);
if (!interactive_mode_) { if (!interactive_mode_) {
std::cout << "\n" << BOLD << GREEN << "📥 Response:" << RESET << std::endl; std::cout << "\n"
<< BOLD << GREEN << "📥 Response:" << RESET << std::endl;
std::cout << response.dump(2) << std::endl; std::cout << response.dump(2) << std::endl;
} }
return response; return response;
} else { } else {
json error_response = { json error_response = {{"error",
{"error", { {{"message", "Request failed"},
{"message", "Request failed"}, {"status", res ? res->status : -1}}}};
{"status", res ? res->status : -1}
}}
};
if (res && !res->body.empty()) { if (res && !res->body.empty()) {
error_response["error"]["details"] = res->body; error_response["error"]["details"] = res->body;
} }
std::cerr << RED << "❌ Request failed. Status: " << (res ? res->status : -1) << RESET << std::endl; std::cerr << RED
<< "❌ Request failed. Status: " << (res ? res->status : -1)
<< RESET << std::endl;
return error_response; return error_response;
} }
} }
json ProcessStreamingRequest(const json& request, const std::string& endpoint) { json ProcessStreamingRequest(const json& request,
const std::string& endpoint) {
std::string accumulated_response; std::string accumulated_response;
// Use same SSE logic for both public and local APIs // Use same SSE logic for both public and local APIs
httplib::Request req; httplib::Request req;
req.method = "POST"; req.method = "POST";
@ -232,72 +254,73 @@ private:
req.set_header("X-goog-api-key", api_key_); req.set_header("X-goog-api-key", api_key_);
} }
req.body = request.dump(); req.body = request.dump();
req.content_receiver = [&accumulated_response, this](const char* data, size_t data_length, uint64_t offset, uint64_t total_length) -> bool { req.content_receiver = [&accumulated_response, this](
std::string chunk(data, data_length); const char* data, size_t data_length,
std::istringstream stream(chunk); uint64_t offset, uint64_t total_length) -> bool {
std::string line; std::string chunk(data, data_length);
std::istringstream stream(chunk);
while (std::getline(stream, line)) { std::string line;
if (line.substr(0, 6) == "data: ") {
std::string event_data = line.substr(6); while (std::getline(stream, line)) {
if (line.substr(0, 6) == "data: ") {
if (event_data == "[DONE]") { std::string event_data = line.substr(6);
if (!interactive_mode_) {
std::cout << "\n\n" << GREEN << "✅ Generation complete!" << RESET << std::endl; if (event_data == "[DONE]") {
} if (!interactive_mode_) {
} else { std::cout << "\n\n"
try { << GREEN << "✅ Generation complete!" << RESET
json event = json::parse(event_data); << std::endl;
if (event.contains("candidates") && !event["candidates"].empty()) { }
auto& candidate = event["candidates"][0]; } else {
if (candidate.contains("content") && candidate["content"].contains("parts")) { try {
for (const auto& part : candidate["content"]["parts"]) { json event = json::parse(event_data);
if (part.contains("text")) { if (event.contains("candidates") &&
std::string text = part["text"].get<std::string>(); !event["candidates"].empty()) {
std::cout << text << std::flush; auto& candidate = event["candidates"][0];
accumulated_response += text; if (candidate.contains("content") &&
} candidate["content"].contains("parts")) {
for (const auto& part : candidate["content"]["parts"]) {
if (part.contains("text")) {
std::string text = part["text"].get<std::string>();
std::cout << text << std::flush;
accumulated_response += text;
} }
} }
} }
} catch (const json::exception& e) {
// Skip parse errors
} }
} catch (const json::exception& e) {
// Skip parse errors
} }
} }
} }
return true; }
}; return true;
};
httplib::Response res; httplib::Response res;
httplib::Error error; httplib::Error error;
bool success = use_https_ ? ssl_client_->send(req, res, error) : client_->send(req, res, error); bool success = use_https_ ? ssl_client_->send(req, res, error)
: client_->send(req, res, error);
if (res.status == 200 && !accumulated_response.empty()) { if (res.status == 200 && !accumulated_response.empty()) {
return json{ return json{
{"candidates", {{ {"candidates",
{"content", { {{{"content", {{"parts", {{{"text", accumulated_response}}}}}}}}}};
{"parts", {{{"text", accumulated_response}}}}
}}
}}}
};
} else { } else {
json error_response = { json error_response = {
{"error", { {"error",
{"message", "Streaming request failed"}, {{"message", "Streaming request failed"}, {"status", res.status}}}};
{"status", res.status}
}}
};
if (!res.body.empty()) { if (!res.body.empty()) {
error_response["error"]["details"] = res.body; error_response["error"]["details"] = res.body;
} }
std::cerr << RED << "❌ Streaming request failed. Status: " << res.status << RESET << std::endl; std::cerr << RED << "❌ Streaming request failed. Status: " << res.status
<< RESET << std::endl;
return error_response; return error_response;
} }
} }
private: private:
std::unique_ptr<httplib::Client> client_; std::unique_ptr<httplib::Client> client_;
std::unique_ptr<httplib::SSLClient> ssl_client_; std::unique_ptr<httplib::SSLClient> ssl_client_;
std::string host_; std::string host_;
@ -308,19 +331,55 @@ private:
bool interactive_mode_; bool interactive_mode_;
}; };
struct ClientArgs : public ArgsBase<ClientArgs> {
ClientArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
ClientArgs() { Init(); }
std::string host;
int port;
std::string api_key;
std::string model;
std::string prompt;
bool interactive;
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(host, "host", std::string("localhost"),
"Server host (default: localhost)");
visitor(port, "port", 8080, "Server port (default: 8080)");
visitor(api_key, "api_key", std::string(""),
"Use public API with key (changes host to "
"generativelanguage.googleapis.com:443)");
visitor(model, "model", std::string("gemma3-4b"),
"Model name to use (default: gemma3-4b)");
visitor(prompt, "prompt", std::string("Hello! How are you?"),
"Prompt for generation (default: 'Hello! How are you?')");
visitor(interactive, "interactive", false,
"Start interactive chat mode (0 = no, 1 = yes)");
}
};
} // namespace gcpp
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gcpp::ClientArgs client_args(argc, argv); gcpp::ConsumedArgs consumed(argc, argv);
gcpp::ClientArgs client_args(argc, argv, consumed);
if (gcpp::HasHelp(argc, argv)) { if (gcpp::HasHelp(argc, argv)) {
std::cout << "\nAPI Client for gemma.cpp\n"; fprintf(stderr,
std::cout << "========================\n\n"; "\nAPI Client for gemma.cpp\n"
"========================\n\n");
client_args.Help(); client_args.Help();
std::cout << std::endl; fprintf(stderr,
std::cout << "Environment Variables:" << std::endl; "\n*Environment Variables:\n"
std::cout << " GOOGLE_API_KEY : Automatically use public Google API if set" << std::endl; " GOOGLE_API_KEY : Automatically use public Google API if set\n");
return 0; return 0;
} }
consumed.AbortIfUnconsumed();
// Check for GOOGLE_API_KEY environment variable // Check for GOOGLE_API_KEY environment variable
const char* env_api_key = std::getenv("GOOGLE_API_KEY"); const char* env_api_key = std::getenv("GOOGLE_API_KEY");
if (env_api_key != nullptr && strlen(env_api_key) > 0) { if (env_api_key != nullptr && strlen(env_api_key) > 0) {
@ -328,32 +387,34 @@ int main(int argc, char* argv[]) {
client_args.host = "generativelanguage.googleapis.com"; client_args.host = "generativelanguage.googleapis.com";
client_args.port = 443; client_args.port = 443;
} }
// Handle API key override // Handle API key override
if (!client_args.api_key.empty()) { if (!client_args.api_key.empty()) {
client_args.host = "generativelanguage.googleapis.com"; client_args.host = "generativelanguage.googleapis.com";
client_args.port = 443; client_args.port = 443;
} }
std::cout << BOLD << YELLOW << "🚀 Testing API Server at " std::cout << BOLD << YELLOW << "🚀 Testing API Server at " << client_args.host
<< client_args.host << ":" << client_args.port << RESET << std::endl; << ":" << client_args.port << RESET << std::endl;
try { try {
APIClient client(client_args.host, client_args.port, client_args.api_key, client_args.model); APIClient client(client_args.host, client_args.port, client_args.api_key,
client_args.model);
if (client_args.interactive) { if (client_args.interactive) {
client.InteractiveChat(); client.InteractiveChat();
} else { } else {
client.TestListModels(); client.TestListModels();
client.TestGenerateContent(client_args.prompt, true); client.TestGenerateContent(client_args.prompt, true);
} }
} catch (const std::exception& e) { } catch (const std::exception& e) {
std::cerr << RED << "❌ Error: " << e.what() << RESET << std::endl; std::cerr << RED << "❌ Error: " << e.what() << RESET << std::endl;
std::cerr << "Make sure the API server is running:" << std::endl; std::cerr << "Make sure the API server is running:" << std::endl;
std::cerr << " ./build/gemma_api_server --tokenizer <path> --weights <path>" << std::endl; std::cerr
<< " ./build/gemma_api_server --tokenizer <path> --weights <path>"
<< std::endl;
return 1; return 1;
} }
return 0; return 0;
} }

View File

@ -15,22 +15,19 @@
// HTTP API server for gemma.cpp with SSE support // HTTP API server for gemma.cpp with SSE support
#include <stdio.h>
#include <signal.h> #include <signal.h>
#include <stdio.h>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <string_view>
#include <vector>
#include <thread>
#include <atomic> #include <atomic>
#include <chrono> #include <chrono>
#include <sstream> #include <iostream>
#include <iomanip> #include <memory>
#include <mutex> #include <mutex>
#include <sstream>
#include <string>
#include <thread> // NOLINT
#include <unordered_map> #include <unordered_map>
#include <vector>
// HTTP server library // HTTP server library
#undef CPPHTTPLIB_OPENSSL_SUPPORT #undef CPPHTTPLIB_OPENSSL_SUPPORT
@ -38,16 +35,12 @@
#include "httplib.h" #include "httplib.h"
// JSON library // JSON library
#include "nlohmann/json.hpp"
#include "compression/types.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "gemma/gemma_args.h" #include "gemma/gemma_args.h"
#include "gemma/tokenizer.h"
#include "ops/matmul.h" #include "ops/matmul.h"
#include "util/args.h" #include "util/args.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/profiler.h" #include "nlohmann/json.hpp"
using json = nlohmann::json; using json = nlohmann::json;
@ -90,7 +83,8 @@ struct ServerState {
std::lock_guard<std::mutex> lock(sessions_mutex); std::lock_guard<std::mutex> lock(sessions_mutex);
auto& session = sessions[session_id]; auto& session = sessions[session_id];
if (!session.kv_cache) { if (!session.kv_cache) {
session.kv_cache = std::make_unique<KVCache>(gemma->Config(), InferenceArgs(), env->ctx.allocator); session.kv_cache = std::make_unique<KVCache>(
gemma->Config(), InferenceArgs(), env->ctx.allocator);
} }
session.last_access = std::chrono::steady_clock::now(); session.last_access = std::chrono::steady_clock::now();
return session; return session;
@ -107,7 +101,8 @@ std::string GenerateSessionId() {
return ss.str(); return ss.str();
} }
// Wraps messages with start_of_turn markers - handles both with and without roles // Wraps messages with start_of_turn markers - handles both with and without
// roles
std::string WrapMessagesWithTurnMarkers(const json& contents) { std::string WrapMessagesWithTurnMarkers(const json& contents) {
std::string prompt; std::string prompt;
@ -121,12 +116,14 @@ std::string WrapMessagesWithTurnMarkers(const json& contents) {
std::string text = part["text"]; std::string text = part["text"];
if (role == "user") { if (role == "user") {
prompt += "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n"; prompt +=
"<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
} else if (role == "model") { } else if (role == "model") {
prompt += text + "\n"; prompt += text + "\n";
} else if (role.empty()) { } else if (role.empty()) {
// Local format without roles - for now, treat as user input // Local format without roles - for now, treat as user input
prompt += "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n"; prompt +=
"<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
} }
} }
} }
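For a two-message history (one user turn, one earlier model reply), the branches above produce the prompt shown below. The snippet builds the expected string by hand rather than calling the file-local function:

#include <cassert>
#include <string>

int main() {
  const std::string user_text = "Hi";
  const std::string model_text = "Hello!";
  // role == "user" wraps with both markers; role == "model" appends text only.
  const std::string expected = "<start_of_turn>user\n" + user_text +
                               "\n<start_of_turn>model\n" + model_text + "\n";
  assert(expected == "<start_of_turn>user\nHi\n<start_of_turn>model\nHello!\n");
  return 0;
}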
@ -163,18 +160,15 @@ RuntimeConfig ParseGenerationConfig(const json& request) {
return config; return config;
} }
// Unified response formatter - creates consistent format regardless of request type // Unified response formatter - creates consistent format regardless of request
json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false) { // type
json CreateAPIResponse(const std::string& text,
bool is_streaming_chunk = false) {
json response = { json response = {
{"candidates", {{ {"candidates",
{"content", { {{{"content", {{"parts", {{{"text", text}}}}, {"role", "model"}}},
{"parts", {{{"text", text}}}}, {"index", 0}}}},
{"role", "model"} {"promptFeedback", {{"safetyRatings", json::array()}}}};
}},
{"index", 0}
}}},
{"promptFeedback", {{"safetyRatings", json::array()}}}
};
// Only add finishReason for non-streaming chunks // Only add finishReason for non-streaming chunks
if (!is_streaming_chunk) { if (!is_streaming_chunk) {
@ -185,7 +179,9 @@ json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false)
} }
// Handle generateContent endpoint (non-streaming) // Handle generateContent endpoint (non-streaming)
void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) { void HandleGenerateContentNonStreaming(ServerState& state,
const httplib::Request& req,
httplib::Response& res) {
try { try {
json request = json::parse(req.body); json request = json::parse(req.body);
@ -199,7 +195,9 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
prompt = WrapMessagesWithTurnMarkers(request["contents"]); prompt = WrapMessagesWithTurnMarkers(request["contents"]);
} else { } else {
res.status = 400; res.status = 400;
res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json"); res.set_content(
json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(),
"application/json");
return; return;
} }
@ -209,12 +207,7 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
// Set up runtime config // Set up runtime config
RuntimeConfig runtime_config = ParseGenerationConfig(request); RuntimeConfig runtime_config = ParseGenerationConfig(request);
// Collect full response runtime_config.stream_token = [](int token, float) { return true; };
std::string full_response;
runtime_config.stream_token = [&full_response](int token, float) {
// Skip EOS token
return true;
};
// Tokenize prompt // Tokenize prompt
std::vector<int> tokens = WrapAndTokenize( std::vector<int> tokens = WrapAndTokenize(
@ -227,7 +220,8 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
// Temporarily redirect output to capture response // Temporarily redirect output to capture response
std::stringstream output; std::stringstream output;
runtime_config.stream_token = [&output, &state, &session, &tokens](int token, float) { runtime_config.stream_token = [&output, &state, &session, &tokens](
int token, float) {
// Skip prompt tokens // Skip prompt tokens
if (session.abs_pos < tokens.size()) { if (session.abs_pos < tokens.size()) {
session.abs_pos++; session.abs_pos++;
@ -279,7 +273,9 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
} }
// Handle streamGenerateContent endpoint with SSE) // Handle streamGenerateContent endpoint with SSE)
void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) { void HandleGenerateContentStreaming(ServerState& state,
const httplib::Request& req,
httplib::Response& res) {
try { try {
json request = json::parse(req.body); json request = json::parse(req.body);
@ -293,7 +289,9 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
prompt = WrapMessagesWithTurnMarkers(request["contents"]); prompt = WrapMessagesWithTurnMarkers(request["contents"]);
} else { } else {
res.status = 400; res.status = 400;
res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json"); res.set_content(
json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(),
"application/json");
return; return;
} }
@ -305,88 +303,85 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
// Set up chunked content provider for SSE // Set up chunked content provider for SSE
res.set_chunked_content_provider( res.set_chunked_content_provider(
"text/event-stream", "text/event-stream", [&state, request, prompt, session_id](
[&state, request, prompt, session_id](size_t offset, httplib::DataSink& sink) { size_t offset, httplib::DataSink& sink) {
try { try {
// Lock for inference // Lock for inference
std::lock_guard<std::mutex> lock(state.inference_mutex); std::lock_guard<std::mutex> lock(state.inference_mutex);
auto& session = state.GetOrCreateSession(session_id); auto& session = state.GetOrCreateSession(session_id);
// Set up runtime config // Set up runtime config
RuntimeConfig runtime_config = ParseGenerationConfig(request); RuntimeConfig runtime_config = ParseGenerationConfig(request);
// Tokenize prompt // Tokenize prompt
std::vector<int> tokens = WrapAndTokenize( std::vector<int> tokens = WrapAndTokenize(
state.gemma->Tokenizer(), state.gemma->ChatTemplate(), state.gemma->Tokenizer(), state.gemma->ChatTemplate(),
state.gemma->Config().wrapping, session.abs_pos, prompt); state.gemma->Config().wrapping, session.abs_pos, prompt);
// Stream token callback
std::string accumulated_text;
auto stream_token = [&](int token, float) {
// Skip prompt tokens
if (session.abs_pos < tokens.size()) {
session.abs_pos++;
return true;
}
// Stream token callback
std::string accumulated_text;
auto stream_token = [&](int token, float) {
// Skip prompt tokens
if (session.abs_pos < tokens.size()) {
session.abs_pos++; session.abs_pos++;
// Check for EOS
if (state.gemma->Config().IsEOS(token)) {
return true;
}
// Decode token
std::string token_text;
state.gemma->Tokenizer().Decode(std::vector<int>{token},
&token_text);
accumulated_text += token_text;
// Send SSE event using unified formatter
json event = CreateAPIResponse(token_text, true);
std::string sse_data = "data: " + event.dump() + "\n\n";
sink.write(sse_data.data(), sse_data.size());
return true; return true;
} };
session.abs_pos++; runtime_config.stream_token = stream_token;
// Check for EOS // Run inference with KV cache
if (state.gemma->Config().IsEOS(token)) { TimingInfo timing_info = {.verbosity = 0};
return true; size_t prefix_end = 0;
}
// Decode token state.gemma->Generate(runtime_config, tokens, session.abs_pos,
std::string token_text; prefix_end, *session.kv_cache, *state.env,
state.gemma->Tokenizer().Decode(std::vector<int>{token}, &token_text); timing_info);
accumulated_text += token_text;
// Send SSE event using unified formatter // Send final event using unified formatter
json event = CreateAPIResponse(token_text, true); json final_event = CreateAPIResponse("", false);
final_event["usageMetadata"] = {
{"promptTokenCount", tokens.size()},
{"candidatesTokenCount", session.abs_pos - tokens.size()},
{"totalTokenCount", session.abs_pos}};
std::string sse_data = "data: " + event.dump() + "\n\n"; std::string final_sse = "data: " + final_event.dump() + "\n\n";
sink.write(sse_data.data(), sse_data.size()); sink.write(final_sse.data(), final_sse.size());
return true; // Send done event
}; sink.write("data: [DONE]\n\n", 15);
runtime_config.stream_token = stream_token;
// Run inference with KV cache
TimingInfo timing_info = {.verbosity = 0};
size_t prefix_end = 0;
state.gemma->Generate(runtime_config, tokens, session.abs_pos,
prefix_end, *session.kv_cache, *state.env,
timing_info);
// Send final event using unified formatter
json final_event = CreateAPIResponse("", false);
final_event["usageMetadata"] = {
{"promptTokenCount", tokens.size()},
{"candidatesTokenCount", session.abs_pos - tokens.size()},
{"totalTokenCount", session.abs_pos}
};
std::string final_sse = "data: " + final_event.dump() + "\n\n";
sink.write(final_sse.data(), final_sse.size());
// Send done event
sink.write("data: [DONE]\n\n", 15);
// Ensure all data is sent
sink.done();
return false; // End streaming
} catch (const std::exception& e) {
json error_event = {{"error", {{"message", e.what()}}}};
std::string error_sse = "data: " + error_event.dump() + "\n\n";
sink.write(error_sse.data(), error_sse.size());
return false;
}
}
);
// Ensure all data is sent
sink.done();
return false; // End streaming
} catch (const std::exception& e) {
json error_event = {{"error", {{"message", e.what()}}}};
std::string error_sse = "data: " + error_event.dump() + "\n\n";
sink.write(error_sse.data(), error_sse.size());
return false;
}
});
} catch (const json::exception& e) { } catch (const json::exception& e) {
res.status = 400; res.status = 400;
res.set_content( res.set_content(
@ -398,20 +393,20 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
} }
// Handle models list endpoint // Handle models list endpoint
void HandleListModels(ServerState& state, const InferenceArgs& inference, const httplib::Request& req, httplib::Response& res) { void HandleListModels(ServerState& state, const InferenceArgs& inference,
const httplib::Request& req, httplib::Response& res) {
json response = { json response = {
{"models", {{ {"models",
{"name", "models/" + inference.model}, {{{"name", "models/" + inference.model},
{"version", "001"}, {"version", "001"},
{"displayName", inference.model}, {"displayName", inference.model},
{"description", inference.model + " model running locally"}, {"description", inference.model + " model running locally"},
{"inputTokenLimit", 8192}, {"inputTokenLimit", 8192},
{"outputTokenLimit", 8192}, {"outputTokenLimit", 8192},
{"supportedGenerationMethods", json::array({"generateContent", "streamGenerateContent"})}, {"supportedGenerationMethods",
{"temperature", 1.0}, json::array({"generateContent", "streamGenerateContent"})},
{"topK", 1} {"temperature", 1.0},
}}} {"topK", 1}}}}};
};
res.set_content(response.dump(), "application/json"); res.set_content(response.dump(), "application/json");
} }
@ -421,39 +416,45 @@ void HandleListModels(ServerState& state, const InferenceArgs& inference, const
// server_running = false; // server_running = false;
// } // }
void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading, void RunServer(const GemmaArgs& args) {
const InferenceArgs& inference) {
std::cerr << "Loading model..." << std::endl; std::cerr << "Loading model..." << std::endl;
// Initialize model // Initialize model
ThreadingContext ctx(threading); ThreadingContext ctx(args.threading);
MatMulEnv env(ctx); MatMulEnv env(ctx);
ServerState state; ServerState state;
state.gemma = std::make_unique<Gemma>(loader, inference, ctx); state.gemma = std::make_unique<Gemma>(args, ctx);
state.env = &env; state.env = &env;
state.ctx = &ctx; state.ctx = &ctx;
const InferenceArgs& inference = args.inference;
httplib::Server server; httplib::Server server;
// Set up routes // Set up routes
server.Get("/", [&inference](const httplib::Request&, httplib::Response& res) { server.Get(
res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" + inference.model + ":generateContent", "text/plain"); "/", [&inference](const httplib::Request&, httplib::Response& res) {
}); res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" +
inference.model + ":generateContent",
"text/plain");
});
// API endpoints // API endpoints
server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req, httplib::Response& res) { server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req,
httplib::Response& res) {
HandleListModels(state, inference, req, res); HandleListModels(state, inference, req, res);
}); });
std::string model_endpoint = "/v1beta/models/" + inference.model; std::string model_endpoint = "/v1beta/models/" + inference.model;
server.Post(model_endpoint + ":generateContent", [&state](const httplib::Request& req, httplib::Response& res) { server.Post(model_endpoint + ":generateContent",
HandleGenerateContentNonStreaming(state, req, res); [&state](const httplib::Request& req, httplib::Response& res) {
}); HandleGenerateContentNonStreaming(state, req, res);
});
server.Post(model_endpoint + ":streamGenerateContent", [&state](const httplib::Request& req, httplib::Response& res) { server.Post(model_endpoint + ":streamGenerateContent",
HandleGenerateContentStreaming(state, req, res); [&state](const httplib::Request& req, httplib::Response& res) {
}); HandleGenerateContentStreaming(state, req, res);
});
// Periodic cleanup of old sessions // Periodic cleanup of old sessions
std::thread cleanup_thread([&state]() { std::thread cleanup_thread([&state]() {
@ -466,12 +467,15 @@ void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading,
std::cerr << "Starting API server on port " << inference.port << std::endl; std::cerr << "Starting API server on port " << inference.port << std::endl;
std::cerr << "Model loaded successfully" << std::endl; std::cerr << "Model loaded successfully" << std::endl;
std::cerr << "Endpoints:" << std::endl; std::cerr << "Endpoints:" << std::endl;
std::cerr << " POST /v1beta/models/" << inference.model << ":generateContent" << std::endl; std::cerr << " POST /v1beta/models/" << inference.model << ":generateContent"
std::cerr << " POST /v1beta/models/" << inference.model << ":streamGenerateContent (SSE)" << std::endl; << std::endl;
std::cerr << " POST /v1beta/models/" << inference.model
<< ":streamGenerateContent (SSE)" << std::endl;
std::cerr << " GET /v1beta/models" << std::endl; std::cerr << " GET /v1beta/models" << std::endl;
if (!server.listen("0.0.0.0", inference.port)) { if (!server.listen("0.0.0.0", inference.port)) {
std::cerr << "Failed to start server on port " << inference.port << std::endl; std::cerr << "Failed to start server on port " << inference.port
<< std::endl;
} }
cleanup_thread.join(); cleanup_thread.join();
@ -482,35 +486,27 @@ void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading,
int main(int argc, char** argv) { int main(int argc, char** argv) {
gcpp::InternalInit(); gcpp::InternalInit();
gcpp::LoaderArgs loader(argc, argv); gcpp::ConsumedArgs consumed(argc, argv);
gcpp::ThreadingArgs threading(argc, argv); gcpp::GemmaArgs args(argc, argv, consumed);
gcpp::InferenceArgs inference(argc, argv);
if (gcpp::HasHelp(argc, argv)) { if (gcpp::HasHelp(argc, argv)) {
std::cerr << "\n\nAPI server for gemma.cpp\n"; fprintf(
std::cout << "========================\n\n"; stderr,
std::cerr << "Usage: " << argv[0] << " --weights <path> --tokenizer <path> [options]\n"; "\n\nAPI server for gemma.cpp\n"
std::cerr << "\nOptions:\n"; "========================\n\n"
std::cerr << " --port PORT Server port (default: 8080)\n"; " --port PORT Server port (default: 8080)\n"
std::cerr << " --model MODEL Model name for endpoints (default: gemma3-4b)\n"; " --model MODEL Model name for endpoints (default: gemma3-4b)\n");
std::cerr << "\n"; args.Help();
std::cerr << "\n*Model Loading Arguments*\n\n";
loader.Help();
std::cerr << "\n*Threading Arguments*\n\n";
threading.Help();
std::cerr << "\n*Inference Arguments*\n\n";
inference.Help();
std::cerr << "\n";
return 0; return 0;
} }
// Arguments are now handled by InferenceArgs consumed.AbortIfUnconsumed();
// // Set up signal handler // // Set up signal handler
// signal(SIGINT, gcpp::HandleShutdown); // signal(SIGINT, gcpp::HandleShutdown);
// signal(SIGTERM, gcpp::HandleShutdown); // signal(SIGTERM, gcpp::HandleShutdown);
gcpp::RunServer(loader, threading, inference); gcpp::RunServer(args);
return 0; return 0;
} }

View File

@ -43,6 +43,7 @@
// After highway.h // After highway.h
#include "compression/compress-inl.h" #include "compression/compress-inl.h"
#include "gemma/flash_attention.h" #include "gemma/flash_attention.h"
#include "gemma/gemma-inl.h"
#include "ops/ops-inl.h" #include "ops/ops-inl.h"
HWY_BEFORE_NAMESPACE(); HWY_BEFORE_NAMESPACE();
@ -57,33 +58,35 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
const MatPtrT<KV_t>& k, float* HWY_RESTRICT att, const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
ThreadingContext& ctx, const size_t worker) { ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kGenAttentionQDotK); GCPP_ZONE(ctx, worker, Zones::kGenAttentionQDotK);
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) { const hn::ScalableTag<BF16> dbf;
// Slightly faster: no wraparound. const size_t qkv_dim = k.Cols();
for (size_t pos = start_pos; pos <= last_pos; ++pos) { HWY_ALIGN BF16 q_bf[kMaxQKVDim];
const float score = Dot(q, k.Row(pos), k.Cols());
att[pos] = score; CompressPerThread tls;
} const hn::ScalableTag<float> df;
} else { CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
for (size_t pos = start_pos; pos <= last_pos; ++pos) { 0);
const size_t pos_modulo = div_seq_len.Remainder(pos);
const float score = Dot(q, k.Row(pos_modulo), k.Cols()); // --seq_len must be large enough to avoid wraparound.
att[pos_modulo] = score; HWY_DASSERT(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()));
} for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const float score =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos), qkv_dim);
att[pos] = score;
} }
} }
void PositionalEncodingQK(float* qk, const size_t layer_idx, void PositionalEncodingQK(float* qk, const size_t layer_idx,
const LayerWeightsPtrs& layer, const AttentionActivationsPtrs& activations,
const AttentionActivations& activations,
ThreadingContext& ctx, const size_t worker, ThreadingContext& ctx, const size_t worker,
const size_t pos, const float mul) { const size_t pos, const float mul) {
const size_t qkv_dim = layer.layer_config.qkv_dim; const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
const PostQKType& post_qk = layer.layer_config.post_qk; const size_t qkv_dim = layer_config.qkv_dim;
const PostQKType& post_qk = layer_config.post_qk;
// qk is either q or k, so qkv_dim is the length we operate on. // qk is either q or k, so qkv_dim is the length we operate on.
const float* inv_timescale = activations.inv_timescale.PackedScale1(); const float* inv_timescale = activations.inv_timescale.PackedScale1();
const bool is_global_layer = activations.config.IsGlobalLayer(layer_idx); const bool is_global_layer = activations.config.IsGlobalLayer(layer_idx);
// TODO: add a config flag instead of hardcoding the model. if (is_global_layer && activations.config.use_global_timescale) {
if (is_global_layer && IsVLM(activations.config.model)) {
inv_timescale = activations.inv_timescale_global.PackedScale1(); inv_timescale = activations.inv_timescale_global.PackedScale1();
} }
// PostQKType::Rope // PostQKType::Rope
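PositionalEncodingQK applies RoPE with an optional scale factor mul, using the selected inv_timescale table. A scalar sketch of a standard RoPE rotation for orientation only; the half-offset pairing of dimensions is the common convention and an assumption here, not necessarily this repository's layout.

#include <cmath>
#include <cstddef>

// Rotates each (i, i + qkv_dim/2) pair of qk by pos * inv_timescale[i] and
// applies the scale `mul`, mirroring the (qk, pos, inv_timescale, mul) inputs.
void RopeSketch(float* qk, size_t qkv_dim, size_t pos,
                const float* inv_timescale, float mul) {
  const size_t half = qkv_dim / 2;
  for (size_t i = 0; i < half; ++i) {
    const float theta = static_cast<float>(pos) * inv_timescale[i];
    const float c = std::cos(theta), s = std::sin(theta);
    const float x0 = qk[i], x1 = qk[i + half];
    qk[i] = mul * (x0 * c - x1 * s);
    qk[i + half] = mul * (x0 * s + x1 * c);
  }
}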
@ -104,62 +107,52 @@ static HWY_INLINE void WeightedSumV(
const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att, const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, ThreadingContext& ctx, const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, ThreadingContext& ctx,
const size_t worker) { const size_t worker) {
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) { // --seq_len must be large enough to avoid wraparound.
// Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if HWY_DASSERT(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()));
// we supported non-transposed B. // TODO: replace with MatMul(att, v) after it supports non-transposed B.
// TODO: 2..4x unroll MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), ctx,
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), ctx, worker);
worker); for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
}
} else {
{
const size_t pos_mod = div_seq_len.Remainder(start_pos);
MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), ctx,
worker);
}
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = div_seq_len.Remainder(pos);
MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols());
}
} }
} }
// Calculates the attention outputs for a single q, which may be updated // Calculates the attention outputs for a single q, which may be updated
// in place for RMSNorm. // in place for RMSNorm.
void SingleDotSoftmaxWeightedSum( void SingleDotSoftmaxWeightedSum(
const size_t pos, const size_t start_pos, const size_t last_pos, const size_t q_pos, const size_t kv_start_pos, const size_t kv_last_pos,
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
const size_t layer_idx, const LayerWeightsPtrs& layer, const MatPtr& query_norm_scale, const size_t layer_idx,
const AttentionActivations& activations, float* HWY_RESTRICT att, const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) { float* HWY_RESTRICT att_out, const SMOptions& sm_options,
ThreadingContext& ctx, const size_t worker) {
const float att_cap = activations.config.att_cap; const float att_cap = activations.config.att_cap;
const float query_scale = activations.query_scale; const float query_scale = activations.query_scale;
const size_t seq_len = // --seq_len must be large enough to avoid wraparound.
static_cast<size_t>(activations.div_seq_len.GetDivisor()); HWY_DASSERT(kv_last_pos < activations.SeqLen());
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
// Apply rope and scaling to Q. // Apply rope and scaling to Q.
if (layer.query_norm_scale.HasPtr()) { if (query_norm_scale.HasPtr()) {
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { CallUpcasted(&query_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q, RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q,
layer.layer_config.qkv_dim, ctx, worker); layer_config.qkv_dim, ctx, worker);
}); });
} }
PositionalEncodingQK(q, layer_idx, layer, activations, ctx, worker, pos, PositionalEncodingQK(q, layer_idx, activations, ctx, worker, q_pos,
query_scale); query_scale);
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker); QDotK(kv_start_pos, kv_last_pos, activations.div_seq_len, q, k, att, ctx,
worker);
// SoftMax with optional SoftCap yields "probabilities" in att. // SoftMax with optional SoftCap yields "probabilities" in att.
const size_t att_len = HWY_MIN(last_pos + 1, seq_len); const Logits logits(att, kv_last_pos + 1);
const Logits logits(att, att_len);
MaybeLogitsSoftCap(att_cap, logits, ctx, worker); MaybeLogitsSoftCap(att_cap, logits, ctx, worker);
Softmax(logits, ctx, worker, /*temperature=*/1.0f); Softmax(logits, ctx, worker, /*temperature=*/1.0f, sm_options);
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out, WeightedSumV(kv_start_pos, kv_last_pos, activations.div_seq_len, att, v,
ctx, worker); att_out, ctx, worker);
} }
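// In equation form (illustrative summary of the steps above), with q already
// rotated and scaled by query_scale inside PositionalEncodingQK:
//   att[p]  = dot(q, k[p])                    for p in [kv_start_pos, kv_last_pos]
//   att     = Softmax(SoftCap(att_cap, att))  // SoftCap only if att_cap != 0
//   att_out = sum_p att[p] * v[p]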
// The attention window usually starts at 0 unless `pos` is larger than // The attention window usually starts at 0 unless `pos` is larger than
@ -170,13 +163,13 @@ size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) {
} }
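// Minimal sketch of the intended behavior (an assumption; the body is elided
// here and the window-size field name may differ):
//   const size_t window = config.attention_window_sizes[layer_idx];
//   return pos >= window ? pos - window + 1 : 0;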
void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer, const MatPtr& query_norm_scale,
AttentionActivations& activations, QBatch& qbatch, AttentionActivationsPtrs& activations,
ThreadingContext& ctx) { QBatch& qbatch, ThreadingContext& ctx) {
GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive); GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive);
const hwy::Divisor div_qbatch(qbatch.Size()); const hwy::Divisor div_qbatch(qbatch.Size());
const LayerConfig& layer_config = layer.layer_config; const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
const size_t qkv_dim = layer_config.qkv_dim; const size_t qkv_dim = layer_config.qkv_dim;
// A "head group" in the context of GQA refers to a collection of query // A "head group" in the context of GQA refers to a collection of query
@ -184,8 +177,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads; const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
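// Worked example: with heads == 8 and kv_heads == 2, kHeadGroups == 4, and
// query head h reads KV head h / kHeadGroups (heads 0..3 -> KV 0, 4..7 -> KV 1).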
const size_t cache_layer_size = layer_config.CacheLayerSize(); const size_t cache_layer_size = layer_config.CacheLayerSize();
const size_t seq_len = const size_t seq_len = activations.SeqLen();
static_cast<size_t>(activations.div_seq_len.GetDivisor());
// All layers should have the same number of heads. // All layers should have the same number of heads.
HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads); HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads);
@ -196,12 +188,12 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
GCPP_ZONE(ctx, worker, Zones::kGenAttentionDotSoftmaxWeightedSumPar); GCPP_ZONE(ctx, worker, Zones::kGenAttentionDotSoftmaxWeightedSumPar);
const size_t qi = div_qbatch.Remainder(tq_idx); const size_t qi = div_qbatch.Remainder(tq_idx);
const size_t batch_idx = div_qbatch.Divide(tq_idx); const size_t token_idx = div_qbatch.Divide(tq_idx);
auto& kv_cache = qbatch.KV(qi).kv_cache; auto& kv_cache = qbatch.KV(qi).kv_cache;
// Find the token position in the query and calculate // Find the token position in the query and calculate
// the range of cache positions to attend to. // the range of cache positions to attend to.
const size_t pos = qbatch.Pos(qi) + batch_idx; const size_t pos = qbatch.Pos(qi) + token_idx;
const size_t start_pos = StartPos(pos, activations.config, layer_idx); const size_t start_pos = StartPos(pos, activations.config, layer_idx);
size_t last_pos = pos; size_t last_pos = pos;
const size_t prefix_end = qbatch.PrefixEnd(qi); const size_t prefix_end = qbatch.PrefixEnd(qi);
@ -214,6 +206,8 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
float* HWY_RESTRICT att = activations.att.Row(tq_idx) + head * seq_len; float* HWY_RESTRICT att = activations.att.Row(tq_idx) + head * seq_len;
float* HWY_RESTRICT att_out = float* HWY_RESTRICT att_out =
activations.att_out.Row(tq_idx) + head * qkv_dim; activations.att_out.Row(tq_idx) + head * qkv_dim;
SMOptions sm_options{.max_out = activations.softmax_max.Row(tq_idx) + head,
.d_out = activations.softmax_d.Row(tq_idx) + head};
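// Presumably the running softmax statistics for this (token, head): the max
// logit m and the denominator d = sum_p exp(att[p] - m), which allow partial
// results to be merged flash-attention style. This reading is an assumption.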
// Make strided read-only views into the kv cache for // Make strided read-only views into the kv cache for
// this query and head. // this query and head.
@ -224,8 +218,10 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim)); MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride()); v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx, SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
layer, activations, att, att_out, ctx, worker); query_norm_scale, layer_idx, activations, att,
att_out, sm_options, ctx, worker);
}; };
{ {
@ -246,7 +242,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
// Fills activations.q and writes to KV cache. // Fills activations.q and writes to KV cache.
static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer, const LayerWeightsPtrs& layer,
AttentionActivations& activations, AttentionActivationsPtrs& activations,
const QBatch& qbatch, const int flags, const QBatch& qbatch, const int flags,
MatMulEnv& env) { MatMulEnv& env) {
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(),
@ -271,10 +267,14 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
layer.qkv_einsum_w2.Rows())); layer.qkv_einsum_w2.Rows()));
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) { ++interleaved_idx) {
// Index into qbatch, within [0, qbatch.Size())
const size_t qi = div_qbatch.Remainder(interleaved_idx); const size_t qi = div_qbatch.Remainder(interleaved_idx);
const size_t batch_idx = div_qbatch.Divide(interleaved_idx); // Index along token sequence, within [0, num_tokens)
const size_t cache_pos = const size_t token_idx = div_qbatch.Divide(interleaved_idx);
activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx); const size_t cache_pos = qbatch.Pos(qi) + token_idx;
// --seq_len must be large enough to avoid wraparound.
HWY_DASSERT(cache_pos < activations.SeqLen());
env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>( env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size); qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
} }
@ -286,15 +286,16 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
// Note that 2D parallelism is not worth the fork/join overhead because the // Note that 2D parallelism is not worth the fork/join overhead because the
// tasks are very lightweight. // tasks are very lightweight.
ParallelFor( ParallelFor(
ParallelismStrategy::kFlat, kv_heads * num_interleaved, env.ctx, Parallelism::kFlat, kv_heads * num_interleaved, env.ctx,
/*cluster_idx=*/0, Callers::kAttComputeQKV, /*cluster_idx=*/0, Callers::kAttComputeQKV,
[&](size_t task, size_t worker) HWY_ATTR { [&](size_t task, size_t worker) HWY_ATTR {
const size_t head = task % kv_heads; const size_t head = task % kv_heads;
const size_t interleaved_idx = task / kv_heads; const size_t interleaved_idx = task / kv_heads;
const size_t qi = div_qbatch.Remainder(interleaved_idx); const size_t qi = div_qbatch.Remainder(interleaved_idx);
const size_t batch_idx = div_qbatch.Divide(interleaved_idx); const size_t token_idx = div_qbatch.Divide(interleaved_idx);
const size_t pos = qbatch.Pos(qi) + batch_idx; const size_t cache_pos = qbatch.Pos(qi) + token_idx;
const size_t cache_pos = activations.div_seq_len.Remainder(pos); // --seq_len must be large enough to avoid wraparound.
HWY_DASSERT(cache_pos < activations.SeqLen());
auto& kv_cache = qbatch.KV(qi).kv_cache; auto& kv_cache = qbatch.KV(qi).kv_cache;
KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) + KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
layer_idx * cache_layer_size + layer_idx * cache_layer_size +
@ -313,35 +314,18 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
}); });
} }
PositionalEncodingQK(kv_f32, layer_idx, layer, activations, env.ctx, PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker,
worker, pos, /*mul=*/1.0f); cache_pos, /*mul=*/1.0f);
CompressPerThread tls; CompressPerThread tls;
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0); Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
}); });
} }
// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
// head_dim (`qkv_dim`) into output (`layer_out`).
static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
AttentionActivations& activations,
MatMulEnv& env) {
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
const LayerConfig& layer_config = layer.layer_config;
(void)layer_config; // For HWY_DASSERT
// att_weights and att_out are concatenated heads, each of length
// layer_config.qkv_dim. Thus the [num_interleaved,
// layer_config.model_dim] matmul output is the sum over heads. Compare
// gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD',
// encoded)
HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 &&
layer_config.qkv_dim != 0);
CallMatMul(activations.att_out, layer.att_weights, /*add=*/nullptr, env,
activations.att_sums);
}
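// Shape sketch for the SumHeads matmul (illustrative, assuming CallMatMul
// computes C = A * B^T): att_out is [num_interleaved, heads * qkv_dim] and
// att_weights is [model_dim, heads * qkv_dim], so att_sums comes out as
// [num_interleaved, model_dim], i.e. per-head contributions are summed over
// the shared heads * qkv_dim inner dimension.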
void GemmaAttention(size_t num_tokens, const size_t layer_idx, void GemmaAttention(size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer, const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch, AttentionActivationsPtrs& activations, QBatch& qbatch,
MatMulEnv& env, int flags) { MatMulEnv& env, int flags) {
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttention); GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttention);
@ -353,13 +337,14 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env); ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
if (flags & kAttentionUseOld) { if (flags & kAttentionUseOld) {
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch, DotSoftmaxWeightedSum(num_tokens, layer_idx, layer.query_norm_scale,
env.ctx); activations, qbatch, env.ctx);
} else { } else {
// * 2 does not help on Turin. // * 2 does not help on Turin.
FlashAttention(num_tokens, FlashAttention(num_tokens,
/*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1, /*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1,
layer_idx, layer, activations, qbatch, env.ctx); layer_idx, layer.query_norm_scale, activations, qbatch,
env.ctx);
} }
SumHeads(layer, activations, env); SumHeads(layer, activations, env);
} }


@ -29,8 +29,7 @@ namespace gcpp {
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \ #define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \ namespace NAMESPACE { \
void PositionalEncodingQK(float* qk, size_t layer_idx, \ void PositionalEncodingQK(float* qk, size_t layer_idx, \
const LayerWeightsPtrs& layer, \ const AttentionActivationsPtrs& activations, \
const AttentionActivations& activations, \
ThreadingContext& ctx, size_t worker, size_t pos, \ ThreadingContext& ctx, size_t worker, size_t pos, \
float mul); \ float mul); \
\ \
@ -39,18 +38,18 @@ namespace gcpp {
void SingleDotSoftmaxWeightedSum( \ void SingleDotSoftmaxWeightedSum( \
const size_t pos, const size_t start_pos, const size_t last_pos, \ const size_t pos, const size_t start_pos, const size_t last_pos, \
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \ float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \ const MatPtr& query_norm_scale, size_t layer_idx, \
const AttentionActivations& activations, float* HWY_RESTRICT att, \ const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att, \
float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker); \ float* HWY_RESTRICT att_out, const SMOptions& sm_options, ThreadingContext& ctx, size_t worker); \
\ \
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \ void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
const LayerWeightsPtrs& layer, \ const MatPtr& query_norm_scale, \
AttentionActivations& activations, \ AttentionActivationsPtrs& activations, \
QBatch& qbatch, ThreadingContext& ctx); \ QBatch& qbatch, ThreadingContext& ctx); \
\ \
void GemmaAttention(size_t num_tokens, const size_t layer_idx, \ void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
const LayerWeightsPtrs& layer, \ const LayerWeightsPtrs& layer, \
AttentionActivations& activations, QBatch& qbatch, \ AttentionActivationsPtrs& activations, QBatch& qbatch, \
MatMulEnv& env, int flags); \ MatMulEnv& env, int flags); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE } // namespace NAMESPACE

gemma/attention_test.cc (new file, 570 lines)

@ -0,0 +1,570 @@
#include <cstddef>
#include <cstring>  // strncmp
#include <memory>
#include <numeric>
#include <optional>
#include <vector>
#include "gtest/gtest.h"
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "gemma/activations.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/kv_cache.h"
#include "gemma/weights.h"
#include "ops/matmul.h"
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#ifndef HWY_DISABLED_TARGETS
// These tests aren't designed to suss out instruction set specific problems.
// Disable most targets to keep the tests fast and simple and not have to
// worry about tolerances on floating point results.
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/attention_test.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "gemma/attention.h"
#include "gemma/configs.h"
#include "util/test_util.h"
#include "hwy/tests/test_util-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
void FillRandom(MatPtrT<float>& mat, uint64_t seed) {
hwy::RandomState rng(seed);
for (size_t r = 0; r < mat.Rows(); ++r) {
float* row = mat.Row(r);
for (size_t c = 0; c < mat.Cols(); ++c) {
row[c] = static_cast<float>(RandomGaussian(rng));
}
}
}
void AllocateAndFillRandom(MatPtr& mat, const Allocator& allocator,
std::vector<MatOwner>& mat_owners, uint64_t seed) {
if (mat.IsEmpty()) return;
if (mat.GetType() == Type::kUnknown) {
mat.SetType(Type::kF32);
}
mat_owners.emplace_back();
mat_owners.back().AllocateFor(mat, allocator, MatPadding::kPacked);
MatPtrT<float> mat_f32(mat);
FillRandom(mat_f32, seed);
}
struct TestState {
TestState() : ctx({}), env(ctx) {}
ThreadingContext ctx;
std::vector<MatOwner> mat_owners;
MatMulEnv env;
};
struct TestModelState {
TestModelState(TestState& state)
: config(Model::GEMMA2_2B, Type::kF32, PromptWrapping::GEMMA_PT),
tensor_info_registry(config),
layer_config(config.layer_configs[0]),
layer(0, layer_config, tensor_info_registry) {
config.att_cap = 1024.0f;
AllocateAndFillRandom(layer.qkv_einsum_w, state.ctx.allocator,
state.mat_owners, 42);
AllocateAndFillRandom(layer.attn_vec_einsum_w, state.ctx.allocator,
state.mat_owners, 43);
AllocateAndFillRandom(layer.gating_einsum_w, state.ctx.allocator,
state.mat_owners, 44);
AllocateAndFillRandom(layer.linear_w, state.ctx.allocator, state.mat_owners,
45);
layer.Fixup(state.mat_owners, state.ctx);
}
ModelConfig config;
TensorInfoRegistry tensor_info_registry;
const LayerConfig& layer_config;
LayerWeightsPtrs layer;
};
struct TestAttentionState {
TestAttentionState(TestState& state, TestModelState& model_state,
size_t num_tokens, size_t qbatch_size,
AttentionImpl attention_impl)
: num_tokens(num_tokens),
qbatch_size(qbatch_size),
batch_size(qbatch_size * num_tokens),
runtime_config{.attention_impl = attention_impl},
tokens(num_tokens),
attention_storage_(model_state.config, model_state.layer_config,
batch_size, num_tokens, runtime_config,
state.ctx.allocator, row_ptrs_),
attention(model_state.config, num_tokens, attention_storage_) {
for (size_t i = 0; i < qbatch_size; ++i) {
kv_caches.emplace_back(model_state.config, inference_args,
state.ctx.allocator);
}
activations.emplace(
runtime_config, model_state.config, runtime_config.prefill_tbatch_size,
kv_caches[0].SeqLen(), state.env.ctx, state.env.row_ptrs);
// Tokens don't matter, since we fill in pre_att_rms_out before calling
// GemmaAttention.
std::iota(tokens.begin(), tokens.end(), 1);
for (size_t i = 0; i < qbatch_size; ++i) {
prompts.emplace_back(tokens);
}
all_queries.emplace(prompts,
hwy::Span<KVCache>(kv_caches.data(), kv_caches.size()));
qbatch.emplace(/*start=*/0, /*max_size=*/qbatch_size, *all_queries);
FillRandom(attention.pre_att_rms_out, 46);
}
const size_t num_tokens;
const size_t qbatch_size;
const size_t batch_size;
InferenceArgs inference_args;
RuntimeConfig runtime_config;
std::vector<KVCache> kv_caches;
std::optional<Activations> activations;
std::vector<int> tokens;
std::vector<PromptTokens> prompts;
std::optional<AllQueries> all_queries;
std::optional<QBatch> qbatch;
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs_;
AttentionActivations attention_storage_;
AttentionActivationsPtrs attention;
};
double GetTolerance() {
const char* target_name = hwy::TargetName(HWY_TARGET);
if (strncmp(target_name, "AVX2", 4) == 0) {
return 2e-2;
} else if (strncmp(target_name, "AVX3", 4) == 0) {
return 3e-4;
} else if (strncmp(target_name, "NEON", 4) == 0) {
return 5e-3;
} else {
return 1e-7;
}
}
template <size_t kNumTokens, size_t kQBatchSize, size_t kDims>
void CompareAttSumsWithGolden(
const AttentionActivationsPtrs& attention,
const float (&golden)[kNumTokens][kQBatchSize][kDims]) {
ASSERT_EQ(attention.att_sums.Rows(), kNumTokens * kQBatchSize);
ASSERT_LE(kDims, attention.att_sums.Cols());
hwy::AlignedFreeUniquePtr<float[]> actual_row =
hwy::AllocateAligned<float>(kDims);
for (size_t token_idx = 0; token_idx < kNumTokens; ++token_idx) {
for (size_t qi = 0; qi < kQBatchSize; ++qi) {
const size_t i = token_idx * kQBatchSize + qi;
for (size_t j = 0; j < kDims; ++j) {
actual_row[j] = hwy::F32FromBF16(attention.att_sums.Row(i)[j]);
}
EXPECT_TRUE(hwy::CompareArraySimilar(
golden[token_idx][qi], actual_row.get(), kDims, GetTolerance(),
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
<< "att_sums mismatch for token_idx=" << token_idx << " qi=" << qi;
}
}
}
template <size_t kNumTokens, size_t kQBatchSize, size_t kDims>
void CompareKVCacheWithGolden(
const ModelConfig& config, hwy::Span<KVCache> kv_caches, const size_t layer,
const size_t kv_head,
const float (&k_golden)[kNumTokens][kQBatchSize][kDims],
const float (&v_golden)[kNumTokens][kQBatchSize][kDims]) {
const size_t qbatch_size = kv_caches.size();
ASSERT_EQ(kQBatchSize, qbatch_size);
const size_t start_offset = 0;
const size_t qkv_dim = config.layer_configs[0].qkv_dim;
hwy::AlignedFreeUniquePtr<float[]> actual_k_row =
hwy::AllocateAligned<float>(kDims);
hwy::AlignedFreeUniquePtr<float[]> actual_v_row =
hwy::AllocateAligned<float>(kDims);
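// Assumed row layout of the KV cache at one position, matching the offsets
// computed below: [layer][kv_head][K (qkv_dim) | V (qkv_dim)], so K starts at
// kv_offset and V follows at kv_offset + qkv_dim.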
const size_t cache_layer_size = config.layer_configs[layer].CacheLayerSize();
const size_t head_offset = kv_head * qkv_dim * 2;
const size_t kv_offset = layer * cache_layer_size + head_offset;
for (size_t token_idx = 0; token_idx < kNumTokens; ++token_idx) {
for (size_t qi = 0; qi < kQBatchSize; ++qi) {
const float* cache_row =
kv_caches[qi].kv_cache.Row(start_offset + token_idx);
for (size_t j = 0; j < kDims; ++j) {
actual_k_row[j] = cache_row[kv_offset + j];
actual_v_row[j] = cache_row[kv_offset + qkv_dim + j];
}
EXPECT_TRUE(hwy::CompareArraySimilar(
k_golden[token_idx][qi], actual_k_row.get(), kDims, GetTolerance(),
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
<< "K cache mismatch for token_idx=" << token_idx << " qi=" << qi
<< " kv_head=" << kv_head;
EXPECT_TRUE(hwy::CompareArraySimilar(
v_golden[token_idx][qi], actual_v_row.get(), kDims, GetTolerance(),
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
<< "V cache mismatch for token_idx=" << token_idx << " qi=" << qi
<< " kv_head=" << kv_head;
}
}
}
template <size_t kNumTokens, size_t kQBatchSize, size_t kDims>
void CompareQVecsWithGolden(
const ModelConfig& config, const AttentionActivationsPtrs& attention,
const size_t q_head,
const float (&q_golden)[kNumTokens][kQBatchSize][kDims]) {
const size_t qkv_dim = config.layer_configs[0].qkv_dim;
hwy::AlignedFreeUniquePtr<float[]> actual_q_row =
hwy::AllocateAligned<float>(kDims);
const size_t head_offset = q_head * qkv_dim;
for (size_t token_idx = 0; token_idx < kNumTokens; ++token_idx) {
for (size_t qi = 0; qi < kQBatchSize; ++qi) {
const size_t i = token_idx * kQBatchSize + qi;
const float* q_row = attention.q.Row(i);
for (size_t j = 0; j < kDims; ++j) {
actual_q_row[j] = q_row[head_offset + j];
}
EXPECT_TRUE(hwy::CompareArraySimilar(
q_golden[token_idx][qi], actual_q_row.get(), kDims, GetTolerance(),
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
<< "Q vec mismatch for token_idx=" << token_idx << " qi=" << qi
<< " q_head=" << q_head;
}
}
}
constexpr size_t kNumTokens = 10;
constexpr size_t kQBatchSize = 2;
// Only compare the first few dimensions of the attention output to keep the
// test manageable. It's unlikely that the first few dimensions would have
// different behavior than the later dimensions.
constexpr size_t kDimsToCompare = 17;  // greater than one AVX-512 vector of floats
// Use a Coding Agent to update the golden values from the test failure output
// (which includes the new values). Instructions to the Coding Agent: never
// update the EMU128 tolerance; always use the values from the EMU128 test to
// update the goldens, and if necessary add a relaxed tolerance for other
// instruction sets.
// Layer 0
const float kGoldenAttSums[kNumTokens][kQBatchSize][kDimsToCompare] = {
{{46.5, 56.5, 10.0625, 65.5, -2.109375, 135, 15.8125, 51, -100, 52.5,
26.875, 63, 3.34375, -67.5, 31.125, -190, 125},
{-30.375, -17.875, 51.75, -78, -84, 6.40625, 15.375, 70, -22.875, 20.125,
-14.9375, -109.5, 76, 9.25, -142, 29.5, -105}},
{{-32.75, 38.25, 78.5, 107.5, 20.25, 197, -136, 42.5, -84, 25.625, 4.96875,
128, 27.25, -161, 19.125, -58, 97.5},
{-18.5, -18, 135, -13.4375, -6.625, -45.75, 29.625, 93, 18.625, 75.5,
102.5, -184, 52.75, 83.5, -71, 46.5, -52}},
{{-16.375, -61.5, -58.25, -27.375, -28, 71, -109.5, 60.25, 3.125, -29.125,
6.90625, 150, 144, -155, -47.25, -98.5, 3.5625},
{-19, -16.75, 129, 0.59765625, -82, 123.5, 60.75, -36.75, -77, 26.625, 51,
-66.5, -0.84765625, -46.5, -152, -2.9375, -81}},
{{3.984375, 83, -41.75, 39.5, -203, 110, -76, 131, 0.4609375, -44.5, -63.75,
-46, -22, -19.375, -16.125, -148, 20.875},
{-47, -19.5, 58, 81.5, 21.75, -30, -118, 44.25, -149, 22.5, 188, -66.5, 33,
10.9375, -52.5, 23.25, 75}},
{{64, -31, -89, -92.5, -11.1875, -54.75, -302, 3.453125, -108, 39.25,
-34.75, 18, -52, 100, -186, -75.5, 50.75},
{7.6875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.4375, 82.5,
39.25, 65, 47.25, -89.5, -34.25, 137}},
{{39.75, 17.875, 115, 38.75, -44, 139, -53.25, -23.875, -13.0625, 38.5,
32.5, 53.75, 109, 4.09375, 57.5, -20.5, 132},
{143, 249, 5.09375, 0.83984375, 27.875, -5.84375, 30.25, -101.5, 65.5,
13.5, 195, -10.0625, 97.5, 2.203125, -97.5, -100, -19.25}},
{{-30.125, -169, -150, 58, -35.75, 22.75, 36.5, -32.25, -8.9375, 55.25,
-117, 26.375, 39.5, 125, 66, 48.75, 20.75},
{137, 5.25, 61.25, 37, -42.75, 240, 62, -164, 11.3125, 173, 174, 23.5,
88.5, 48.5, -46.25, -36.75, 101.5}},
{{-103, -47.5, 39, -48, -67.5, 121, -136, 99, 80, -47.5, 107.5, 48.75, 97.5,
125, -53.5, -14.625, 262},
{29.875, 7.34375, -36.75, -14.5, -27.5, 44.75, -67.5, -40.75, 71.5, 172,
81, -27.25, -3.03125, 111, -167, 59, 176}},
{{-37.25, 109.5, -26.125, -115.5, 108, 57.25, 1.3671875, 72, -122.5, 59.25,
-52, -12.625, 43.25, 16.25, -41.75, 26.5, 70.5},
{40.25, 53.25, -142, 78.5, 38, 4.3125, -27.75, -134, -85, 107.5, 2.5, 93.5,
58.25, 173, -53.5, 25.125, 4.8125}},
{{-8.4375, -35, -35.5, 131, -33.25, 106, 109.5, -92, -135, 80, 21.5,
-17.125, 15.25, 143, -27, 103, 101},
{-77, 40.75, -10.125, 33.25, -33, 104, -7.6875, 85.5, -40, 93, 61, 14.5625,
8.125, -99.5, 13.6875, -11.6875, 33}},
};
// Layer 0, *K*V Head 0
const float kGoldenK[kNumTokens][kQBatchSize][kDimsToCompare] = {
{{-4.51717567, 6.93118095, 6.48003578, 9.12825584, 2.38755274, 11.8121576,
1.65376127, 5.04456615, -7.19549274, 2.57609844, 3.55331731, -3.48494458,
-8.90498638, 9.66047478, -0.379868984, 6.37043715, -2.24351144},
{0.152208567, 3.14520073, -8.35154343, 5.44226503, -6.74000502,
-1.43484437, -4.72092056, -9.48932, -6.12409401, -1.55352509, -3.90701318,
2.12124252, 3.93649936, -8.09877586, -3.30277514, -0.898857355,
1.76684189}},
{{4.378829, 5.05565643, -7.63948059, -5.74608946, 2.90109587, 0.155819178,
4.56115055, 1.37885749, 1.48427355, -1.07145202, 2.82399392, -1.20864201,
3.05434561, -2.65185618, -0.0731391, -8.2279253, 7.63228416},
{-0.702698231, 1.49563932, 6.42149782, -6.68306589, 1.85317755,
-7.70267582, 2.07357907, -7.60303402, -0.514724255, 0.308567047,
5.99250412, -4.67359257, -3.49322176, -2.62086344, -3.18411255,
2.04027057, -4.29057407}},
{{-1.20844436, 4.14724302, 6.04515219, 8.7753458, -0.975198627, 0.564640105,
5.39941597, 4.64036179, 0.366614938, 3.48258138, -0.470701456, 15.2267399,
4.63302803, 9.12662697, -5.89148045, 2.25731587, 5.24449492},
{4.57078934, -4.60315752, -3.3364439, 1.29875994, -3.40833569, -6.95262,
-6.39040232, -6.60212612, 6.63269806, -0.815209687, -5.0346446,
-4.13564968, 8.25674057, -6.0910182, -8.21130085, -8.91020393,
10.6188011}},
{{0.602011144, 2.22505236, 3.62411499, -4.07026958, 12.8036356, 3.76139069,
6.99502087, 7.02500725, -2.51568675, 4.2489934, 0.00210827589,
-1.43267739, -2.10394144, -0.0506809056, -1.54883039, 4.3740139,
-1.61869526},
{-6.37204599, -3.34989691, 2.10935307, 4.23634195, 5.79134035, 13.502944,
-2.19158888, -1.55771351, -1.22244942, 3.36499929, -2.11375904,
-4.5448761, 1.0611912, -2.47849369, -0.212709218, 0.363292456,
7.91467094}},
{{-8.85739231, -4.08585882, -0.618261, 6.52911091, 5.14922285, 7.6869874,
0.750387549, -0.812200725, 2.7509625, 6.29693508, -1.77248931, 5.68896484,
-6.9369607, -4.61359406, 0.184977874, -1.27769828, -2.1619854},
{-8.2555, 2.84032059, -1.03791106, 2.07648611, -4.94546843, 1.76888537,
-1.75901175, 11.2628574, 1.41086221, -3.58669901, -2.85925198, 2.29133463,
1.55509436, -0.0553357825, -10.0363655, 1.94261, -2.95691729}},
{{0.919141412, 1.97533965, -11.3202848, -3.3137629, -4.7161727, 5.07012081,
1.76256621, 8.20588207, 6.05700159, -3.89765406, -1.13639557, -1.32326794,
-3.01544905, -0.585309267, 2.60637712, 2.83708405, -3.39202118},
{9.11918, 2.11261511, -5.87290621, 11.6033278, -4.66597795, -7.13774204,
-9.10563755, -2.48294282, 3.35282946, -3.75122213, 0.404774547,
-9.11625195, 4.85711479, 1.43184578, 1.47673059, -4.75093, -3.45323014}},
{{4.17705393, -4.95192289, -10.5068378, 3.90004015, -3.51306129, 5.38068056,
0.901511431, 11.222868, 2.67285442, 9.18779, 5.61346769, 3.06534624,
-3.78898215, 0.767340839, 15.8207836, -4.14079094, -4.63177109},
{3.61795235, -7.00262165, 2.08284521, -6.70515728, 1.93205631, 2.84467721,
3.94591737, -6.18882942, -1.78465152, -9.39100933, -10.8780289,
6.32468653, 6.53142738, -3.30765963, 2.89132166, 4.53347206, 1.89792418}},
{{-0.361971855, -1.57735932, 5.07296801, -1.55669761, -1.44996238,
7.29838896, 5.23075104, -0.512441278, -3.59834242, 2.38584423, 6.48518324,
-1.48220074, -2.4264791, 10.7237988, 5.64735842, 5.6251297, -7.04244423},
{-0.795628309, 7.30230665, -1.71035647, -16.6999454, 3.05102086,
-4.9243927, 4.28508186, -0.694577456, 6.58464718, 4.40330124, 3.3250041,
1.90579033, -6.29048729, 2.55308104, -4.9746747, -0.681708, -5.98152351}},
{{2.57555652, -3.5651083, 0.784440041, -4.7043705, 2.37520599, -3.62385964,
-3.48913693, -7.28049421, -5.48726082, 1.95519221, 7.25192928, 3.07074118,
-11.9897156, 5.92244673, 5.07564354, 0.162699938, -6.00809956},
{5.56260443, -5.7683115, 1.26402235, -17.507719, 4.18873024, -3.20694613,
-4.42512083, 1.78077614, -3.25167561, 0.864362717, 0.474019766,
-7.92327404, -2.27795148, -0.436354101, -3.15722394, 0.415780187,
2.60931611}},
{{-9.43858051, 0.391518891, -2.74012518, 4.9842453, 7.48263216, -16.3434925,
-4.75156116, -1.99114823, 3.99918842, -5.95400572, 10.8700314, 1.07596064,
0.30389142, 8.39548779, -5.11913681, 5.45641088, -5.63240337},
{-1.22347319, 9.57339382, -1.31736016, -5.02770805, -4.81617355,
-1.96618557, -0.456317186, 12.6451035, -1.50221801, 6.7991147,
-5.97842169, 1.85410941, -8.44729, 0.378282309, 0.0442156792, 17.6773052,
-7.43491}},
};
// Layer 0, K*V* Head 0
const float kGoldenV[kNumTokens][kQBatchSize][kDimsToCompare] = {
{{2.77553034, -7.67514181, -1.60433948, 4.67795134, -1.75084186, 8.57896423,
-1.15065813, -3.75088787, -4.7442131, -1.68890858, -10.0202332,
-4.20167446, 9.36844635, 13.7364845, 11.5634, 2.95288706, 2.89380026},
{-4.79950905, -1.66658688, 4.14471292, -4.95649052, -5.4200325, 3.52626801,
-10.9432049, 0.338347554, -1.53204226, 0.473476171, -0.58271, 1.42195463,
0.301399827, -4.40214968, -2.12298298, 9.27825642, -0.690600872}},
{{-10.6566734, 4.12785721, 4.54053593, -1.39667869, -1.55028772, 0.20508635,
-0.00620913506, 2.93214, -0.788117647, 2.78032446, -2.68898249, 9.5985508,
-10.6630878, -11.9006901, 0.851743698, 0.581826329, 5.21927929},
{-0.322291255, 2.63848567, -2.30808377, -13.0153809, 2.74378228,
3.21460533, 0.688529968, 2.37544608, 6.06825066, 4.57566404, 1.17124248,
-7.96587658, -2.65279341, 4.75271225, -4.09937954, -10.3570251,
3.30500841}},
{{-3.34342527, 6.03099537, 6.335958, 0.993818045, 0.905343294, 6.93058586,
3.9635396, 10.8044815, 7.8620863, -10.1157322, -3.92666101, -0.183003783,
-5.27309418, -1.45110512, -8.96734, -2.63866425, 2.19913912},
{16.416317, -1.62025332, 2.3161006, 3.32571959, -1.79581594, -10.2925539,
-5.86338425, -6.36642933, 9.18872166, 5.95524168, 6.38640785, 8.23832,
-6.57342291, -14.2017632, 1.10925388, 4.27255058, -2.65661311}},
{{6.58254147, -6.96165133, -4.97437, -2.33467388, 5.83671236, -0.794236898,
-2.03117108, -3.93387103, -5.96872902, 5.83316422, 3.01795, -4.05260706,
-4.39556885, 3.24399853, 10.1573639, 4.71967888, 0.274738848},
{7.13243389, -8.04649162, 2.53055143, 2.0771277, -0.667295456, -13.0285645,
0.960428238, -2.11983275, 8.18105602, -6.72609901, -5.46944714,
0.204244614, 0.0900330544, 8.86620903, 4.63697529, 3.19756651,
2.99392676}},
{{9.52539158, -4.3840766, -6.94514465, -2.75913763, -10.8364506,
-3.95606327, 2.43603897, -5.78482246, -0.801304817, 8.23436832,
-7.11484337, 2.53943753, -0.652261257, 9.77392, 3.53345847, -9.62052822,
16.0471916},
{6.89768124, 2.36394405, -2.08569574, -0.682706833, 3.38872, -6.28313875,
4.79594612, 4.93417454, -6.40791416, -10.7355442, -5.66094208, 2.44881392,
1.99794042, -9.19855404, -4.02383137, -3.63013959, -5.65853405}},
{{1.64614546, -3.93421197, -0.48935914, 5.48284435, -7.69781828, 11.8203125,
1.81672478, -1.42535269, -5.26496315, -5.31612349, -4.19499826,
7.06049395, 0.18029356, -0.0519902706, 10.317358, 2.19345617, 3.5296216},
{7.52353811, 3.56836724, 0.414305687, 0.340799928, 2.44263697, 7.52111912,
0.246491909, -11.1172791, -3.82061529, 3.24794388, 0.751524329,
3.14019632, 6.33881855, -0.169233799, 7.82640171, 1.5389179, 8.15851307}},
{{-2.48950672, -8.55112743, 8.04663277, -5.77116871, -0.637019753,
-7.65882111, -7.49037457, 3.8041625, -3.57038307, 9.37715435, -6.42604256,
1.62610793, -1.54000568, 2.52110147, 5.30775261, -4.10454893,
-4.96251774},
{-2.95554614, -5.18210888, 1.00015664, -4.03864431, -7.14954519,
5.99929142, 5.86350155, 2.03810191, -4.23009968, 9.39885902, -5.68198299,
2.72845244, 11.7133255, 0.838779449, -13.2235403, 2.94607735,
-2.7902379}},
{{2.86876941, -0.836064458, -0.374509573, -0.277966499, 3.20654631,
-3.68510771, -7.76134634, 2.23905277, -8.35530376, 5.25071716,
-1.38490796, -2.93542218, 0.509032726, -3.57361269, -2.82580233,
-4.49954033, 2.91235542},
{-4.37938213, 4.78577232, 2.03453469, 5.48564529, -1.05589461, -1.65940428,
4.0130887, 5.26074123, 4.67537832, 0.791350365, 6.3880868, 2.50402451,
7.6603322, -3.16343474, -2.71949649, 4.61576128, 1.3817997}},
{{0.289200783, 7.06031752, -1.15099299, -5.29136801, -1.343642, -8.36283112,
4.13158274, -1.93137062, 3.16199875, 2.21854591, 2.18270063, 0.77002573,
6.90393353, -0.644045949, -5.62211609, -1.09085155, 1.07821059},
{-3.04716778, -2.52233481, -5.99031925, 2.80152273, 0.340899587,
0.667474508, -2.39674735, 8.83768654, -5.45613146, -1.55994594, -2.216362,
1.49354, -4.27255821, -9.05310917, 5.90691471, -1.29772806, -8.50278}},
{{-3.1383903, -7.71573353, 3.38072681, 6.07642221, -2.39587545, -7.84178352,
-1.60108304, -8.6121521, -5.151721, 4.17612457, -2.86532378, 1.64645958,
-0.37970829, -4.34561253, -0.454322815, 0.331385136, -5.74550819},
{4.77026033, -5.51171303, -7.38155365, -5.38462543, 2.95842505, 5.18372536,
0.521988213, 7.23966122, -4.90852165, 7.18465281, 2.99289083, 10.0519466,
-2.09695673, 7.34368706, -2.40495348, 3.61603308, 0.131510735}},
};
// Layer 0, QHead 0
const float kGoldenQ[kNumTokens][kQBatchSize][kDimsToCompare] = {
{{-0.574401975, 0.370210886, -0.426894158, -0.543187439, -0.0266762674,
-0.177960411, -0.00839618221, 0.411925405, 0.536462784, 0.528389931,
-0.499812007, -0.123897657, -0.0170236826, 0.266041577, -0.0781469196,
-0.44081074, 0.185976267},
{0.270543516, -0.109283224, -0.58602041, -0.358663559, -0.393124342,
-0.0895933211, -0.632167816, 0.386703, 0.314152211, 0.0554139167,
0.0241559595, -0.194484815, 0.143893063, 0.103837147, -0.384245932,
-0.00418212265, 0.385817379}},
{{-0.0331106335, -0.100827977, 0.322449774, 0.225943685, -0.384854138,
-0.208085626, 0.0206767023, 0.287796348, -0.139513299, 0.255447835,
-0.0845065042, -0.0619940236, 0.477489054, 0.517492294, -0.0172665715,
-0.0302075297, 0.365989387},
{-0.0266781822, -0.453293771, 0.560033202, 0.105156079, -0.35259968,
0.711447716, -0.253611088, 0.0487165749, -0.086192511, -0.0338740349,
-0.655441046, 0.00413730741, -0.510472536, -0.0748229772, -0.29113093,
-0.0432077348, 0.09223634}},
{{-0.321974993, -0.466039479, 0.207254037, -0.126807183, -0.192775592,
-0.0953654051, 0.209789664, 0.405356169, -0.00627984107, -0.0590961352,
0.0907663852, -0.190793216, -0.730463982, 0.340142608, -0.295675993,
-0.165913597, -0.233714506},
{-0.345578939, 0.394073665, 0.299743414, -0.0075177839, -0.288939595,
0.127782941, -0.207550645, 0.0655022636, -0.705084503, -0.241842598,
0.333820701, 0.217911497, 0.29735288, 0.0147881694, -0.152306199,
-0.589594781, -0.373093933}},
{{0.216089666, 0.0918798149, 0.0560657382, -0.157523662, -0.00141695142,
0.51770103, 0.596379519, -0.271057904, 0.241035417, -0.275827706,
0.112851456, 0.026878573, -0.579843462, -0.5116328, 0.192026839,
0.125176072, 0.34234497},
{-0.0744233653, 0.180814236, 0.170143247, -0.337861449, -0.175804421,
0.213403732, -0.173699334, 0.109528325, -0.385727316, 0.109683953,
0.475667775, 0.253016889, 0.477347463, 0.111096457, 0.394625545,
0.0172286481, -0.357992649}},
{{-0.350524545, -0.142550975, -0.212269634, -0.0589753427, -0.434021264,
0.384472728, 0.445421219, -0.635599554, -0.246593416, 0.120986834,
0.623568773, -0.161932915, -0.702406883, 0.44038102, 0.268234134,
0.480264157, 0.103595078},
{-0.227436215, 0.357608706, -0.25339672, -0.0683218762, -0.179259315,
0.23657614, 0.559984326, 0.165754288, -0.0402980596, -0.101906747,
-0.278261065, -0.16327399, 0.235923961, -0.428657919, -0.290629387,
0.579215467, -0.0717103705}},
{{-0.246389642, -0.266164362, -0.0967710763, -0.4011603, 0.242542207,
0.0869855583, 0.20158039, 0.207793877, -0.0875666738, -0.242263764,
-0.0462955758, -0.617374003, 0.454443514, 0.207072973, -0.0235372931,
-0.0193868056, -0.660622239},
{0.703284621, 0.0382430181, 0.43997851, -0.858277559, 0.342218578,
0.414044619, 0.403636098, -0.579880178, -1.12243, -0.112913512,
0.629238605, -0.0285760984, -0.152203664, -0.088969171, -0.0681343,
0.476349175, 0.283238202}},
{{0.138267457, 0.483219147, 0.230450034, -0.568304598, 0.204461277,
-0.286731184, -0.416590065, -0.483460307, -0.561008453, 0.395195067,
0.104367018, -0.196090236, -0.324770749, -0.0881370157, -0.626873195,
0.0936089084, 0.262185335},
{0.282603383, 0.0723766163, -0.206548154, 0.561849833, 0.482716829,
0.135281503, -0.438841999, 0.472577304, -0.346201897, -0.0211652666,
-0.0905084163, -0.168639392, -0.154975936, -0.303443581, -0.41771856,
0.400717318, 0.426146686}},
{{-0.0537007451, -0.227346331, -0.2871463, 0.247746795, -0.0975416005,
-0.0123391449, 0.0612513907, -0.374673814, 0.283457696, 0.40945363,
0.137944818, -0.0119741419, 0.775918365, -0.308365196, 0.230615795,
-0.440364927, 0.218536288},
{0.0688965544, -0.149037778, -0.246169299, 0.0599289536, -0.456733435,
0.0808929354, 0.115154952, 0.0997388735, -0.408117741, 0.576600909,
-0.193775773, 0.0340575948, -0.29254055, 0.695465446, 0.373336494,
0.421431482, 0.00197479129}},
{{0.402076721, -0.118151993, 0.542394996, 0.0382412486, -0.614983976,
0.28617692, 0.318540633, -0.299300969, -0.177486539, 0.394140214,
0.0644133314, -0.0321308076, 0.671587527, -0.0173831787, -0.219400048,
-0.340277791, 0.5130288},
{0.105372488, -0.145784974, 0.0695323348, -0.106080391, -0.755512118,
0.975362539, -0.15056029, 0.58882606, -0.059625227, -0.810613,
-0.321623206, 0.193939567, 0.0340242684, -0.626081824, 0.109950632,
-0.141072854, 0.0177994221}},
{{0.243249148, 0.0904035419, -0.472183734, -0.176162, 0.314925164,
-0.191137731, 0.492265761, -0.0120046511, 0.824757636, 0.298175,
0.148151726, -0.0197859108, -0.64297086, 0.432318538, -0.555079758,
0.101636633, 0.155741245},
{0.0523641109, 0.224086404, 0.0143201668, 0.0090854, 0.304901183,
-0.391372293, 0.267655343, 0.117368169, 0.645064473, 0.336050332,
-0.282133281, -0.231817603, 0.376230389, -0.575031936, -0.628365576,
0.484799922, 0.0824087635}},
};
void RunAttentionTest(AttentionImpl attention_impl) {
TestState state;
TestModelState model_state(state);
TestAttentionState attention_state(state, model_state, kNumTokens,
kQBatchSize, attention_impl);
GemmaAttention(attention_state.tokens.size(), 0, model_state.layer,
attention_state.attention, *attention_state.qbatch, state.env,
AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16));
CompareAttSumsWithGolden(attention_state.attention, kGoldenAttSums);
CompareKVCacheWithGolden(model_state.config,
hwy::Span<KVCache>(attention_state.kv_caches.data(),
attention_state.kv_caches.size()),
/*layer=*/0, /*kv_head=*/0, kGoldenK, kGoldenV);
CompareQVecsWithGolden(model_state.config, attention_state.attention,
/*q_head=*/0, kGoldenQ);
}
void TestGemmaAttentionOld() { RunAttentionTest(AttentionImpl::kOld); }
void TestGemmaAttentionFlash() { RunAttentionTest(AttentionImpl::kFlash); }
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_BEFORE_TEST(AttentionTest);
HWY_EXPORT_AND_TEST_P(AttentionTest, TestGemmaAttentionOld);
HWY_EXPORT_AND_TEST_P(AttentionTest, TestGemmaAttentionFlash);
HWY_AFTER_TEST();
} // namespace gcpp
#endif


@ -73,45 +73,38 @@ GemmaContext* GemmaContext::Create(const char* tokenizer_path,
ThreadingArgs threading_args; ThreadingArgs threading_args;
threading_args.spin = gcpp::Tristate::kFalse; threading_args.spin = gcpp::Tristate::kFalse;
LoaderArgs loader(tokenizer_path, weights_path); GemmaArgs args(LoaderArgs(tokenizer_path, weights_path), threading_args);
LogDebug("LoaderArgs created");
// Initialize cached args // Initialize cached args
LogDebug("Initializing inference args"); LogDebug("Initializing inference args");
InferenceArgs inference_args; args.inference.max_generated_tokens = max_generated_tokens;
inference_args.Init(); args.inference.temperature = 0.7f;
inference_args.max_generated_tokens = max_generated_tokens; args.inference.top_k = 1;
inference_args.temperature = 0.7f; args.inference.deterministic = false;
inference_args.top_k = 1;
inference_args.deterministic = false;
ss.str(""); ss.str("");
ss << "Inference args initialized with max_tokens: " << max_generated_tokens ss << "Inference args initialized with max_tokens: " << max_generated_tokens
<< ", temperature: " << inference_args.temperature << ", temperature: " << args.inference.temperature
<< ", top_k: " << inference_args.top_k << ", deterministic: " << ", top_k: " << args.inference.top_k << ", deterministic: "
<< (inference_args.deterministic ? "true" : "false"); << (args.inference.deterministic ? "true" : "false");
LogDebug(ss.str().c_str()); LogDebug(ss.str().c_str());
return new GemmaContext(loader, inference_args, threading_args, return new GemmaContext(args, max_generated_tokens);
max_generated_tokens);
} }
GemmaContext::GemmaContext(const LoaderArgs& loader, GemmaContext::GemmaContext(const GemmaArgs& args, int max_generated_tokens)
const InferenceArgs& inference_args, : args(args),
const ThreadingArgs& threading_args, ctx(args.threading),
int max_generated_tokens)
: inference_args(inference_args),
threading_args(threading_args),
ctx(threading_args),
matmul_env(ctx), matmul_env(ctx),
active_conversation_name("default"), active_conversation_name("default"),
model(loader, inference_args, matmul_env.ctx) { model(args, matmul_env.ctx) {
std::stringstream ss; std::stringstream ss;
LogDebug("Creating initial ConversationData"); LogDebug("Creating initial ConversationData");
// Create the initial ConversationData object using make_shared // Create the initial ConversationData object using make_shared
active_conversation = std::make_shared<ConversationData>( active_conversation = std::make_shared<ConversationData>(
model.Config(), inference_args, ctx.allocator); model.Config(), args.inference, ctx.allocator);
LogDebug( LogDebug(
"Storing initial ConversationData in conversation_cache[\"default\"]"); "Storing initial ConversationData in conversation_cache[\"default\"]");
@ -172,8 +165,8 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
// set up runtime config // set up runtime config
TimingInfo timing_info = {}; TimingInfo timing_info = {};
RuntimeConfig runtime_config = {.stream_token = stream_token, RuntimeConfig runtime_config = {.stream_token = stream_token,
.use_spinning = threading_args.spin}; .use_spinning = args.threading.spin};
inference_args.CopyTo(runtime_config); args.inference.CopyTo(runtime_config);
size_t prefix_end = 0; size_t prefix_end = 0;
const ModelConfig& model_config = model.Config(); const ModelConfig& model_config = model.Config();
@ -247,7 +240,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
timing_info); timing_info);
// prepare for next turn // prepare for next turn
if (!inference_args.multiturn || if (!args.inference.multiturn ||
model_config.wrapping == PromptWrapping::PALIGEMMA) { model_config.wrapping == PromptWrapping::PALIGEMMA) {
// If not multiturn, or Paligemma (which handles turns differently), // If not multiturn, or Paligemma (which handles turns differently),
// reset the *active* conversation's position. // reset the *active* conversation's position.


@ -53,8 +53,7 @@ typedef void (*GemmaLogCallback)(const char* message, void* user_data);
class GemmaContext { class GemmaContext {
private: private:
GemmaContext(const LoaderArgs& loader, const InferenceArgs& inference_args, GemmaContext(const GemmaArgs& args, int max_generated_tokens);
const ThreadingArgs& threading_args, int max_generated_tokens);
public: public:
static GemmaContext* Create(const char* tokenizer_path, static GemmaContext* Create(const char* tokenizer_path,
@ -81,37 +80,37 @@ class GemmaContext {
// Set max generated tokens // Set max generated tokens
void SetMaxGeneratedTokens(size_t value) { void SetMaxGeneratedTokens(size_t value) {
inference_args.max_generated_tokens = value; args.inference.max_generated_tokens = value;
LogDebug("Setting max_generated_tokens to configured value"); LogDebug("Setting max_generated_tokens to configured value");
} }
// Set multiturn flag (0 = disabled, 1 = enabled) // Set multiturn flag (0 = disabled, 1 = enabled)
void SetMultiturn(int value) { void SetMultiturn(int value) {
inference_args.multiturn = value; args.inference.multiturn = value;
LogDebug("Setting multiturn to configured value"); LogDebug("Setting multiturn to configured value");
} }
// Set temperature for token generation // Set temperature for token generation
void SetTemperature(float value) { void SetTemperature(float value) {
inference_args.temperature = value; args.inference.temperature = value;
LogDebug("Setting temperature to configured value"); LogDebug("Setting temperature to configured value");
} }
// Set top_k parameter for sampling // Set top_k parameter for sampling
void SetTopK(int value) { void SetTopK(int value) {
inference_args.top_k = value; args.inference.top_k = value;
LogDebug("Setting top_k to configured value"); LogDebug("Setting top_k to configured value");
} }
// Set deterministic flag // Set deterministic flag
void SetDeterministic(bool value) { void SetDeterministic(bool value) {
inference_args.deterministic = value; args.inference.deterministic = value;
LogDebug("Setting deterministic flag to configured value"); LogDebug("Setting deterministic flag to configured value");
} }
// Set prefill_tbatch_size // Set prefill_tbatch_size
void SetPrefillTbatchSize(size_t value) { void SetPrefillTbatchSize(size_t value) {
inference_args.prefill_tbatch_size = value; args.inference.prefill_tbatch_size = value;
LogDebug("Setting prefill_tbatch_size to configured value"); LogDebug("Setting prefill_tbatch_size to configured value");
} }
@ -175,7 +174,7 @@ class GemmaContext {
active_conversation->abs_pos = 0; active_conversation->abs_pos = 0;
// Replace the cache within the current ConversationData object // Replace the cache within the current ConversationData object
active_conversation->kv_cache = std::make_unique<KVCache>( active_conversation->kv_cache = std::make_unique<KVCache>(
model.Config(), inference_args, ctx.allocator); model.Config(), args.inference, ctx.allocator);
LogDebug((log_prefix + "Successfully rewound to initial state.").c_str()); LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
} else { } else {
@ -193,7 +192,7 @@ class GemmaContext {
LogDebug("Creating new conversation"); LogDebug("Creating new conversation");
// Create a new ConversationData object using make_shared // Create a new ConversationData object using make_shared
conversation_cache[name] = std::make_shared<ConversationData>( conversation_cache[name] = std::make_shared<ConversationData>(
model.Config(), inference_args, ctx.allocator); model.Config(), args.inference, ctx.allocator);
return true; return true;
} }
@ -274,8 +273,7 @@ class GemmaContext {
std::vector<int> token_buffer; std::vector<int> token_buffer;
// Cached args (remain global for the context) // Cached args (remain global for the context)
InferenceArgs inference_args; GemmaArgs args;
ThreadingArgs threading_args;
ThreadingContext ctx; ThreadingContext ctx;
MatMulEnv matmul_env; MatMulEnv matmul_env;


@ -22,8 +22,8 @@
#include <vector> #include <vector>
#include "compression/types.h" // Type #include "compression/types.h" // Type
#include "io/fields.h" // IFields #include "io/fields.h" // IFields
#include "io/io.h" // Path #include "io/io.h" // Path
#include "hwy/base.h" #include "hwy/base.h"
namespace gcpp { namespace gcpp {
@ -238,6 +238,7 @@ static ModelConfig ConfigGemma3_1B() {
config.display_name = "Gemma3_1B"; config.display_name = "Gemma3_1B";
config.model = Model::GEMMA3_1B; config.model = Model::GEMMA3_1B;
config.wrapping = PromptWrapping::GEMMA_VLM; config.wrapping = PromptWrapping::GEMMA_VLM;
config.use_global_timescale = true;
config.model_dim = 1152; config.model_dim = 1152;
config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer
config.max_seq_len = 32 * 1024; config.max_seq_len = 32 * 1024;
@ -288,6 +289,7 @@ static ModelConfig ConfigGemma3_4B() {
config.display_name = "Gemma3_4B"; config.display_name = "Gemma3_4B";
config.model = Model::GEMMA3_4B; config.model = Model::GEMMA3_4B;
config.wrapping = PromptWrapping::GEMMA_VLM; config.wrapping = PromptWrapping::GEMMA_VLM;
config.use_global_timescale = true;
AddVitConfig(config, /*image_size=*/896); AddVitConfig(config, /*image_size=*/896);
config.vocab_size = kGemmaV3VocabSize; config.vocab_size = kGemmaV3VocabSize;
config.vit_config.pool_dim = 4; config.vit_config.pool_dim = 4;
@ -337,6 +339,7 @@ static ModelConfig ConfigGemma3_12B() {
config.display_name = "Gemma3_12B"; config.display_name = "Gemma3_12B";
config.model = Model::GEMMA3_12B; config.model = Model::GEMMA3_12B;
config.wrapping = PromptWrapping::GEMMA_VLM; config.wrapping = PromptWrapping::GEMMA_VLM;
config.use_global_timescale = true;
AddVitConfig(config, /*image_size=*/896); AddVitConfig(config, /*image_size=*/896);
config.vocab_size = kGemmaV3VocabSize; config.vocab_size = kGemmaV3VocabSize;
config.vit_config.pool_dim = 4; config.vit_config.pool_dim = 4;
@ -386,6 +389,7 @@ static ModelConfig ConfigGemma3_27B() {
config.display_name = "Gemma3_27B"; config.display_name = "Gemma3_27B";
config.model = Model::GEMMA3_27B; config.model = Model::GEMMA3_27B;
config.wrapping = PromptWrapping::GEMMA_VLM; config.wrapping = PromptWrapping::GEMMA_VLM;
config.use_global_timescale = true;
AddVitConfig(config, /*image_size=*/896); AddVitConfig(config, /*image_size=*/896);
config.vocab_size = kGemmaV3VocabSize; config.vocab_size = kGemmaV3VocabSize;
config.vit_config.pool_dim = 4; config.vit_config.pool_dim = 4;
@ -495,19 +499,19 @@ const char* ModelPrefix(Model model) {
} }
PromptWrapping ChooseWrapping(const Model model, Tristate wrapping) { PromptWrapping ChooseWrapping(const Model model, Tristate wrapping) {
if (IsPaliGemma(model)) { const PromptWrapping config_wrapping = ConfigFromModel(model).wrapping;
// For models with a fixed wrapping mode, ignore user override.
if (config_wrapping == PromptWrapping::PALIGEMMA ||
config_wrapping == PromptWrapping::GEMMA_VLM) {
if (wrapping != Tristate::kDefault) { if (wrapping != Tristate::kDefault) {
HWY_WARN("Ignoring unnecessary --wrapping for PaliGemma models."); HWY_WARN("Ignoring unnecessary --wrapping for model %s.",
ModelPrefix(model));
} }
return PromptWrapping::PALIGEMMA; return config_wrapping;
} }
if (IsVLM(model)) {
if (wrapping != Tristate::kDefault) { // For other models, default to IT unless --wrapping=0 is passed.
HWY_WARN("Ignoring unnecessary --wrapping for VLM models.");
}
return PromptWrapping::GEMMA_VLM;
}
// Default to IT unless --wrapping=0.
return wrapping == Tristate::kFalse ? PromptWrapping::GEMMA_PT return wrapping == Tristate::kFalse ? PromptWrapping::GEMMA_PT
: PromptWrapping::GEMMA_IT; : PromptWrapping::GEMMA_IT;
} }
@ -674,7 +678,9 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) {
return Model::GEMMA3_270M; return Model::GEMMA3_270M;
case 26: case 26:
if (layer_types & kDeducedViT) return Model::GEMMA3_1B; if (layer_types & (kDeducedViT | kDeducedKqNorm)) {
return Model::GEMMA3_1B;
}
return Model::GEMMA2_2B; return Model::GEMMA2_2B;
case 27: case 27:
return (layer_types & kDeduced448) ? Model::PALIGEMMA2_3B_448 return (layer_types & kDeduced448) ? Model::PALIGEMMA2_3B_448
@ -706,4 +712,11 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) {
} }
} }
AttentionImpl GetAttentionImpl(const std::string& impl) {
if (impl == "old") return AttentionImpl::kOld;
if (impl == "flash") return AttentionImpl::kFlash;
HWY_WARN("Unknown attention implementation: %s. Using kOld.\n", impl.c_str());
return AttentionImpl::kOld;
}
} // namespace gcpp } // namespace gcpp


@ -80,6 +80,38 @@ static inline bool EnumValid(LayerAttentionType type) {
return type == LayerAttentionType::kGemma || type == LayerAttentionType::kVit; return type == LayerAttentionType::kGemma || type == LayerAttentionType::kVit;
} }
enum class AttentionImpl {
kOld,
kFlash,
kSentinel,
};
AttentionImpl GetAttentionImpl(const std::string& impl);
/*
* Returns a bitmask of flags to pass to attention functions based on the
* attention implementation selected.
*
* If `hwy_native_dot_bf16` is true, the function will use the old attention
* implementation, ignoring `impl`.
*
* `hwy_native_dot_bf16` needs to be passed in, because the HWY_NATIVE_DOT_BF16
* macro is not available outside of highway instrumented translation units and
* cannot be made accessible from .h files.
*/
static inline int AttentionImplToFlags(AttentionImpl impl,
int hwy_native_dot_bf16) {
if (hwy_native_dot_bf16) return kAttentionUseOld;
switch (impl) {
case AttentionImpl::kOld:
return kAttentionUseOld;
case AttentionImpl::kFlash:
default:
return 0;
}
}
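// Hypothetical call site from a SIMD translation unit (the flag string and
// variable names are assumptions, for illustration only):
//   const AttentionImpl impl = GetAttentionImpl(attention_impl_string);
//   const int flags = AttentionImplToFlags(impl, HWY_NATIVE_DOT_BF16);
//   GemmaAttention(num_tokens, layer_idx, layer, activations, qbatch, env, flags);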
// Post attention and ffw normalization type. // Post attention and ffw normalization type.
enum class PostNormType { enum class PostNormType {
None, None,
@ -184,13 +216,6 @@ enum class Model {
// in Specifier and thus does not change. // in Specifier and thus does not change.
const char* ModelPrefix(Model model); const char* ModelPrefix(Model model);
// Gemma3 is multimodal and has a different prompt wrapping than PaliGemma.
// This is used for deducing the PromptWrapping for pre-2025 BlobStore.
static inline bool IsVLM(Model model) {
return model == Model::GEMMA3_4B || model == Model::GEMMA3_1B ||
model == Model::GEMMA3_12B || model == Model::GEMMA3_27B;
}
static inline bool IsPaliGemma(Model model) { static inline bool IsPaliGemma(Model model) {
if (model == Model::PALIGEMMA2_3B_224 || model == Model::PALIGEMMA2_3B_448 || if (model == Model::PALIGEMMA2_3B_224 || model == Model::PALIGEMMA2_3B_448 ||
model == Model::PALIGEMMA2_10B_224 || model == Model::PALIGEMMA2_10B_224 ||
@ -280,7 +305,7 @@ struct LayerConfig : public IFields {
uint32_t kv_heads = 0; uint32_t kv_heads = 0;
uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous). uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous).
bool ff_biases = false; bool ff_biases = false;
bool optimized_gating = true; // for Gemma3 bool optimized_gating = true; // for Gemma3
PostNormType post_norm = PostNormType::None; PostNormType post_norm = PostNormType::None;
LayerAttentionType type = LayerAttentionType::kGemma; LayerAttentionType type = LayerAttentionType::kGemma;
ActivationType activation = ActivationType::Gelu; ActivationType activation = ActivationType::Gelu;
@ -383,6 +408,8 @@ struct ModelConfig : public IFields {
internal.VisitFields(visitor); internal.VisitFields(visitor);
visitor(use_global_timescale);
// Append new fields here, then update `python/configs.cc`. // Append new fields here, then update `python/configs.cc`.
} }
@ -481,6 +508,7 @@ struct ModelConfig : public IFields {
std::vector<std::string> scale_base_names; std::vector<std::string> scale_base_names;
InternalModelConfig internal; InternalModelConfig internal;
bool use_global_timescale = false; // for Gemma 3
}; };
// Returns the sub-config for the ViT model of the PaliGemma model. // Returns the sub-config for the ViT model of the PaliGemma model.
@ -489,6 +517,7 @@ ModelConfig GetVitConfig(const ModelConfig& config);
enum DeducedLayerTypes { enum DeducedLayerTypes {
kDeducedViT = 2, kDeducedViT = 2,
kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224. kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224.
kDeducedKqNorm = 8,
}; };
// layer_types is one or more of `DeducedLayerTypes`. // layer_types is one or more of `DeducedLayerTypes`.

View File

@ -17,12 +17,18 @@
#include <stdint.h> #include <stdint.h>
#include <algorithm> #include <algorithm>
#include <array>
#include <cmath> #include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits> #include <limits>
#include <vector>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS #include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "gemma/flash_structs.h"
#include "util/threading_context.h" #include "util/threading_context.h"
#include "util/zones.h" #include "util/zones.h"
#include "hwy/base.h"
#ifndef HWY_DISABLED_TARGETS #ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS
@ -30,7 +36,6 @@
#include "gemma/activations.h" #include "gemma/activations.h"
#include "gemma/configs.h" // kMaxQKVDim #include "gemma/configs.h" // kMaxQKVDim
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "gemma/weights.h"
#include "util/threading.h" #include "util/threading.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"
@ -59,7 +64,7 @@ static constexpr size_t kNFx8HTileSize = 8;
// q has shape [batch, qbatch][head, qkv_dim]. // q has shape [batch, qbatch][head, qkv_dim].
// q_t has shape [qkv_dim][qbatch, head, batch] in order to make the maximum // q_t has shape [qkv_dim][qbatch, head, batch] in order to make the maximum
// possible consecutive elements have the same KV. // possible consecutive elements have the same KV.
static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t, static void TransposeQ(const MatPtrT<float>& q, MatPtrT<BF16>& q_t,
const size_t qbatch_size, ThreadingContext& ctx) { const size_t qbatch_size, ThreadingContext& ctx) {
// Group floats by the number of floats in a cache line. // Group floats by the number of floats in a cache line.
const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float); const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float);
@ -70,12 +75,13 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
for (size_t lane = 0; lane < kNF; ++lane) { for (size_t lane = 0; lane < kNF; ++lane) {
size_t q_row = task * kNF + lane; size_t q_row = task * kNF + lane;
if (q_row >= q_t.Rows()) break; if (q_row >= q_t.Rows()) break;
float* HWY_RESTRICT qt_row = q_t.Row(q_row); BF16* HWY_RESTRICT qt_row = q_t.Row(q_row);
for (size_t qi = 0; qi < qbatch_size; ++qi) { for (size_t qi = 0; qi < qbatch_size; ++qi) {
for (size_t h = 0; h < num_heads; ++h) { for (size_t h = 0; h < num_heads; ++h) {
for (size_t b = 0; b < batch_size; ++b) { for (size_t b = 0; b < batch_size; ++b) {
qt_row[(qi * num_heads + h) * batch_size + b] = qt_row[(qi * num_heads + h) * batch_size + b] =
q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row]; hwy::ConvertScalarTo<BF16>(
q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row]);
} }
} }
} }
@ -84,45 +90,48 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
{ {
const size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF); const size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF);
// Better than kFlat. // Better than kFlat.
ParallelFor(ParallelismStrategy::kHierarchical, num_tasks, ctx, ParallelFor(Parallelism::kHierarchical, num_tasks, ctx,
/*cluster_idx=*/0, Callers::kFlashTransposeQ, func); /*cluster_idx=*/0, Callers::kFlashTransposeQ, func);
} }
} }
// Updates q in place for RMSNorm and positional encoding. // Updates q in place for RMSNorm and positional encoding.
void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
MatPtrT<KV_t>& q, const size_t layer_idx, MatPtrT<float>& q,
const LayerWeightsPtrs& layer, const MatPtr& query_norm_scale,
const AttentionActivations& activations, const size_t layer_idx,
const AttentionActivationsPtrs& activations,
ThreadingContext& ctx) { ThreadingContext& ctx) {
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
const float query_scale = activations.query_scale; const float query_scale = activations.query_scale;
const hwy::Divisor div_qbatch(qbatch.Size()); const hwy::Divisor div_qbatch(qbatch.Size());
const auto func = [&](const size_t task, size_t worker) HWY_ATTR { const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionRmsNormAndPositionalEncoding); GCPP_ZONE(ctx, worker, Zones::kFlashAttentionRmsNormAndPositionalEncoding);
size_t qi = div_qbatch.Remainder(task); size_t qi = div_qbatch.Remainder(task);
size_t batch_idx = div_qbatch.Divide(task); size_t batch_idx = div_qbatch.Divide(task);
for (size_t h = 0; h < layer.layer_config.heads; ++h) { for (size_t h = 0; h < layer_config.heads; ++h) {
const size_t tq_idx = qbatch.Size() * batch_idx + qi; const size_t tq_idx = qbatch.Size() * batch_idx + qi;
// Find the token position in the query and calculate // Find the token position in the query and calculate
// the range of cache positions to attend to. // the range of cache positions to attend to.
const size_t pos = qbatch.Pos(qi) + batch_idx; const size_t pos = qbatch.Pos(qi) + batch_idx;
float* HWY_RESTRICT q_row = float* HWY_RESTRICT q_row = q.Row(tq_idx) + h * layer_config.qkv_dim;
q.Row(tq_idx) + h * layer.layer_config.qkv_dim;
// Apply rope and scaling to Q. // Apply rope and scaling to Q.
if (layer.query_norm_scale.HasPtr()) { if (query_norm_scale.HasPtr()) {
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { CallUpcasted(&query_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q_row, RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q_row,
layer.layer_config.qkv_dim, ctx, worker); layer_config.qkv_dim, ctx, worker);
}); });
} }
PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx, worker, PositionalEncodingQK(q_row, layer_idx, activations, ctx, worker, pos,
pos, query_scale); query_scale);
} }
}; };
{ {
// kHierarchical is not worth the extra sync overhead because the tasks are // kHierarchical is not worth the extra sync overhead because the tasks are
// very lightweight. // very lightweight.
ParallelFor(ParallelismStrategy::kFlat, num_tokens * qbatch.Size(), ctx, ParallelFor(Parallelism::kFlat, num_tokens * qbatch.Size(), ctx,
/*cluster_idx=*/0, Callers::kFlashRMSNormAndPositionalEncoding, /*cluster_idx=*/0, Callers::kFlashRMSNormAndPositionalEncoding,
func); func);
} }
@ -152,15 +161,20 @@ void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max,
// Calculates the complete attention outputs for a single row of q. // Calculates the complete attention outputs for a single row of q.
void SingleFlashAttention(const size_t start_pos, const size_t last_pos, void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
const float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const BF16* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
const MatPtrT<KV_t>& v, const size_t layer_idx, const MatPtrT<KV_t>& v, const size_t layer_idx,
const LayerWeightsPtrs& layer, const AttentionActivationsPtrs& activations,
const AttentionActivations& activations,
float* HWY_RESTRICT att_out, ThreadingContext& ctx, float* HWY_RESTRICT att_out, ThreadingContext& ctx,
const size_t worker) { const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention); GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);
const hn::ScalableTag<BF16> dbf;
const size_t qkv_dim = k.Cols();
const size_t pos_mod = activations.div_seq_len.Remainder(start_pos); const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
float m = Dot(q, k.Row(pos_mod), k.Cols()); // TODO: Mixed-mode can be further improved for Turin: we can demote right
// before we do the dot product instruction, rather than promote both to f32.
// But some potential accuracy loss there, needs evaluation first.
float m = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
if (float cap = activations.config.att_cap; cap > 0.0f) { if (float cap = activations.config.att_cap; cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x. // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
m = cap * std::tanh(m / cap); m = cap * std::tanh(m / cap);
@ -170,7 +184,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker); MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker);
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = activations.div_seq_len.Remainder(pos); const size_t pos_mod = activations.div_seq_len.Remainder(pos);
float x = Dot(q, k.Row(pos_mod), k.Cols()); float x = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
SingleFlashAttentionStep(x, activations.config.att_cap, m, d, SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
v.Row(pos_mod), v.Cols(), att_out); v.Row(pos_mod), v.Cols(), att_out);
} }
@ -180,25 +194,27 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
// the dot products of NF rows of Q for a single K timestep. // the dot products of NF rows of Q for a single K timestep.
template <class DF, class VF = hn::Vec<DF>> template <class DF, class VF = hn::Vec<DF>>
VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets, VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
const size_t k_pos, const MatPtrT<KV_t>& q, const size_t k_pos, const MatPtrT<BF16>& q,
const MatPtrT<KV_t>& k) { const MatPtrT<KV_t>& k) {
const hn::ScalableTag<BF16> dbf;
const size_t qkv_dim = k.Cols();
hn::TFromD<DF> results[hn::MaxLanes(df)]; hn::TFromD<DF> results[hn::MaxLanes(df)];
for (size_t i = 0; i < hn::Lanes(df); ++i) { for (size_t i = 0; i < hn::Lanes(df); ++i) {
results[i] = Dot(q.Row(0) + q_offsets[i], k.Row(k_pos), k.Cols()); results[i] = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[i], qkv_dim), 0,
k.Row(k_pos), qkv_dim);
} }
return hn::LoadU(df, results); return hn::LoadU(df, results);
} }
// Returns an NF Q rows by 8 K rows tile of Q.K dot products, in single // Returns an NF Q rows by 8 K rows tile of Q.K dot products.
// precision.
// This is the result of NF rows of Q against 8 K timesteps, with positions // This is the result of NF rows of Q against 8 K timesteps, with positions
// given by k_pos[0..7]. Q has been transposed so that the NF rows are read in // given by k_pos[0..7]. Q has been transposed so that the NF rows are read in
// consecutive elements, and other columns by adding q_stride. // consecutive elements, and other columns by adding q_stride.
template <class DF, class VF = hn::Vec<DF>> template <class DF, class VF = hn::Vec<DF>>
void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride, void QDotKTile(DF df, const BF16* HWY_RESTRICT q, const size_t q_stride,
const MatPtrT<KV_t>& k, const size_t* k_pos, VF& sum0, const MatPtrT<KV_t>& k, const size_t* k_pos, VF& sum0, VF& sum1,
VF& sum1, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7) {
VF& sum7) {
constexpr size_t kHTileSize = kNFx8HTileSize; constexpr size_t kHTileSize = kNFx8HTileSize;
sum0 = hn::Zero(df); sum0 = hn::Zero(df);
sum1 = hn::Zero(df); sum1 = hn::Zero(df);
@ -209,11 +225,16 @@ void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
sum6 = hn::Zero(df); sum6 = hn::Zero(df);
sum7 = hn::Zero(df); sum7 = hn::Zero(df);
const float* HWY_RESTRICT k_row[kHTileSize]; const float* HWY_RESTRICT k_row[kHTileSize];
for (int i = 0; i < kHTileSize; ++i) { for (size_t i = 0; i < kHTileSize; ++i) {
k_row[i] = k.Row(k_pos[i]); k_row[i] = k.Row(k_pos[i]);
} }
const hn::Rebind<BF16, DF> dbfh;
using VBF = hn::Vec<decltype(dbfh)>;
for (size_t i = 0; i < k.Cols(); ++i) { for (size_t i = 0; i < k.Cols(); ++i) {
VF q_vec = hn::Load(df, q); const VBF q_vec_bf = hn::Load(dbfh, q);
const VF q_vec = hn::PromoteTo(df, q_vec_bf);
VF k_0 = hn::Set(df, k_row[0][i]); VF k_0 = hn::Set(df, k_row[0][i]);
sum0 = hn::MulAdd(q_vec, k_0, sum0); sum0 = hn::MulAdd(q_vec, k_0, sum0);
VF k_1 = hn::Set(df, k_row[1][i]); VF k_1 = hn::Set(df, k_row[1][i]);
@ -266,31 +287,30 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos, // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
// max_last_pos]. // max_last_pos].
void TileFlashAttention( void TileFlashAttention(
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
const StridedView<float>& qT, const MatPtrT<KV_t>& k, const StridedView<BF16>& qT, const MatPtrT<KV_t>& k, const size_t start_pos,
const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos, const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
const size_t min_last_pos, const size_t max_last_pos, const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
const MatPtrT<KV_t>& v, const size_t layer_idx, const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
const LayerWeightsPtrs& layer, const AttentionActivations& activations, const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx,
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets, const size_t worker) {
ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention); GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention);
constexpr int kHTileSize = kNFx8HTileSize; constexpr size_t kHTileSize = kNFx8HTileSize;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
const DF df; const DF df;
using VF = hn::Vec<DF>; using VF = hn::Vec<DF>;
using DI = hn::ScalableTag<uint32_t>; using DI = hn::ScalableTag<uint32_t>;
const DI di; const DI di;
using VI = hn::Vec<DI>; using VI = hn::Vec<DI>;
const int kVTileSize = hn::Lanes(df); const size_t kVTileSize = hn::Lanes(df);
for (int i = 0; i < kVTileSize; ++i) { for (size_t i = 0; i < kVTileSize; ++i) {
hwy::ZeroBytes(att_out.Row(0) + out_offsets[i], hwy::ZeroBytes(att_out.Row(0) + out_offsets[i],
v.Cols() * sizeof(att_out.Row(0)[0])); v.Cols() * sizeof(att_out.Row(0)[0]));
} }
VI lasts = hn::LoadU(di, last_pos); VI lasts = hn::LoadU(di, last_pos);
VF old_m = hn::Set(df, -std::numeric_limits<float>::max() / 2.0f); VF old_m = hn::Set(df, -std::numeric_limits<float>::max() / 2.0f);
VF old_d = hn::Zero(df); VF old_d = hn::Zero(df);
const float* HWY_RESTRICT qT_row = qT.Row(0); const BF16* HWY_RESTRICT qT_row = qT.Row(0);
const size_t qT_stride = qT.Stride(); const size_t qT_stride = qT.Stride();
size_t position = start_pos; size_t position = start_pos;
while (position + kHTileSize - 1 <= min_last_pos) { while (position + kHTileSize - 1 <= min_last_pos) {
@ -299,8 +319,7 @@ void TileFlashAttention(
k_pos[i] = activations.div_seq_len.Remainder(position + i); k_pos[i] = activations.div_seq_len.Remainder(position + i);
} }
VF x0, x1, x2, x3, x4, x5, x6, x7; VF x0, x1, x2, x3, x4, x5, x6, x7;
QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6, QDotKTile(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6, x7);
x7);
if (activations.config.att_cap > 0.0f) { if (activations.config.att_cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile. // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
VF cap = hn::Set(df, activations.config.att_cap); VF cap = hn::Set(df, activations.config.att_cap);
@ -374,7 +393,7 @@ void TileFlashAttention(
// This is the result of 4 rows of Q against NF K timesteps, with positions // This is the result of 4 rows of Q against NF K timesteps, with positions
// given by k_offsets[0..NF]. // given by k_offsets[0..NF].
template <class DF, class VF = hn::Vec<DF>> template <class DF, class VF = hn::Vec<DF>>
void QDotKTilex4(DF df, const float* HWY_RESTRICT q, void QDotKTilex4(DF df, const BF16* HWY_RESTRICT q,
const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<KV_t>& k, const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<KV_t>& k,
const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1, const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1,
VF& sum2, VF& sum3) { VF& sum2, VF& sum3) {
@ -389,13 +408,13 @@ void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
VI k_offsets_vec = hn::LoadU(di, k_offsets); VI k_offsets_vec = hn::LoadU(di, k_offsets);
for (size_t i = 0; i < k.Cols(); ++i) { for (size_t i = 0; i < k.Cols(); ++i) {
VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec); VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec);
VF q_0 = hn::Set(df, q[q_offsets[0] + i]); VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[0] + i]));
sum0 = hn::MulAdd(q_0, k_vec, sum0); sum0 = hn::MulAdd(q_0, k_vec, sum0);
VF q_1 = hn::Set(df, q[q_offsets[1] + i]); VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[1] + i]));
sum1 = hn::MulAdd(q_1, k_vec, sum1); sum1 = hn::MulAdd(q_1, k_vec, sum1);
VF q_2 = hn::Set(df, q[q_offsets[2] + i]); VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[2] + i]));
sum2 = hn::MulAdd(q_2, k_vec, sum2); sum2 = hn::MulAdd(q_2, k_vec, sum2);
VF q_3 = hn::Set(df, q[q_offsets[3] + i]); VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[3] + i]));
sum3 = hn::MulAdd(q_3, k_vec, sum3); sum3 = hn::MulAdd(q_3, k_vec, sum3);
} }
} }
@ -410,23 +429,202 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
float scale = old_d * std::exp(old_max - m); float scale = old_d * std::exp(old_max - m);
old_d = hn::ReduceSum(df, x) + scale; old_d = hn::ReduceSum(df, x) + scale;
old_max = m; old_max = m;
float one_over_d = 1.0f / old_d; if (old_d > 0.0f) {
scale *= one_over_d; const float one_over_d = 1.0f / old_d;
x = hn::Mul(x, hn::Set(df, one_over_d)); scale *= one_over_d;
x = hn::Mul(x, hn::Set(df, one_over_d));
} else {
scale = 0.0f;
x = hn::Zero(df);
}
return scale; return scale;
} }
// Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to // Reduces each of x_0..x_3 with `reducer` and stores the four results in consecutive lanes (tested with float32).
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos, template <class DF, typename T = hn::TFromD<DF>,
// max_last_pos]. class DF4 = hn::CappedTag<T, 4>, class VF4 = hn::Vec<DF4>,
void TileFlashAttention4( class VF = hn::Vec<DF>, typename F>
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets, static VF4 HWY_INLINE Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,
F reducer) {
const DF4 df4;
constexpr size_t kMaxLanes = hn::MaxLanes(df);
HWY_LANES_CONSTEXPR size_t kLanes = hn::Lanes(df);
HWY_ALIGN T x_transposed[4 * kMaxLanes];
hn::StoreInterleaved4<DF>(x_0, x_1, x_2, x_3, df, x_transposed);
VF4 result = hn::Load(df4, x_transposed);
for (size_t i = 1; i < kLanes; ++i) {
result = reducer(result, hn::Load(df4, x_transposed + i * 4));
}
return result;
}
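A scalar sketch of what `Reduce4` computes may help: lane j of the returned 4-lane vector is `reducer` folded over every lane of x_j. Illustration only, shown with a max reducer; `ReduceLaneMax` is a hypothetical helper, not part of this change:

#include <algorithm>
#include <cstddef>

// Lane j of Reduce4's result equals this fold over the lanes of x_j.
static float ReduceLaneMax(const float* x_j, size_t lanes) {
  float r = x_j[0];
  for (size_t i = 1; i < lanes; ++i) r = std::max(r, x_j[i]);
  return r;  // becomes lane j of the 4-lane result
}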
// Handles up to 4 Q rows by NF*2 timesteps of flash attention.
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
float* HWY_RESTRICT old_max, float* HWY_RESTRICT old_d,
float* HWY_RESTRICT scales) {
using DF4 = hn::CappedTag<float, 4>;
const DF4 df4;
using VF4 = hn::Vec<DF4>;
static_assert(kNumQueries >= 1 && kNumQueries <= 4);
VF4 new_max = hn::Set(df4, -std::numeric_limits<float>::max() / 2.0f);
VF max_0 = hn::Zero(df), max_1 = hn::Zero(df), max_2 = hn::Zero(df), max_3 = hn::Zero(df);
max_0 = hn::Max(x_0_p0, x_0_p1);
if constexpr (kNumQueries >= 2) {
max_1 = hn::Max(x_1_p0, x_1_p1);
}
if constexpr (kNumQueries >= 3) {
max_2 = hn::Max(x_2_p0, x_2_p1);
}
if constexpr (kNumQueries >= 4) {
max_3 = hn::Max(x_3_p0, x_3_p1);
}
if constexpr (kNumQueries == 1) {
new_max = hn::InsertLane(new_max, 0, hn::ReduceMax(df, max_0));
} else {
new_max = Reduce4(df, max_0, max_1, max_2, max_3,
[](auto a, auto b) { return hn::Max(a, b); });
}
if (att_cap > 0.0f) {
VF4 cap = hn::Set(df4, att_cap);
VF4 one_over_cap = hn::Set(df4, one_over_att_cap);
new_max = hn::Mul(cap, hn::Tanh(df4, hn::Mul(new_max, one_over_cap)));
}
VF4 old_max_vf = hn::Set(df4, -std::numeric_limits<float>::max() / 2.0f);
old_max_vf = hn::LoadU(df4, old_max);
new_max = hn::Max(new_max, old_max_vf);
// TODO: Figure out what was wrong with broadcasts and change to that.
HWY_ALIGN float tmp_max[4];
hn::Store(new_max, df4, tmp_max);
if constexpr (kNumQueries >= 1) {
const VF new_max_0 = hn::Set(df, tmp_max[0]);
x_0_p0 = hn::Exp(df, hn::Sub(x_0_p0, new_max_0));
x_0_p1 = hn::Exp(df, hn::Sub(x_0_p1, new_max_0));
}
if constexpr (kNumQueries >= 2) {
const VF new_max_0 = hn::Set(df, tmp_max[1]);
x_1_p0 = hn::Exp(df, hn::Sub(x_1_p0, new_max_0));
x_1_p1 = hn::Exp(df, hn::Sub(x_1_p1, new_max_0));
}
if constexpr (kNumQueries >= 3) {
const VF new_max_0 = hn::Set(df, tmp_max[2]);
x_2_p0 = hn::Exp(df, hn::Sub(x_2_p0, new_max_0));
x_2_p1 = hn::Exp(df, hn::Sub(x_2_p1, new_max_0));
}
if constexpr (kNumQueries >= 4) {
const VF new_max_0 = hn::Set(df, tmp_max[3]);
x_3_p0 = hn::Exp(df, hn::Sub(x_3_p0, new_max_0));
x_3_p1 = hn::Exp(df, hn::Sub(x_3_p1, new_max_0));
}
VF4 old_d_vf = hn::Set(df4, 0.0f);
old_d_vf = hn::LoadU(df4, old_d);
VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max)));
hn::StoreU(new_max, df4, old_max);
VF4 x_sum = hn::Zero(df4);
if constexpr (kNumQueries == 1) {
x_sum = hn::Set(df4, hn::ReduceSum(df, x_0_p0) + hn::ReduceSum(df, x_0_p1));
} else {
VF x_0_sum = hn::Add(x_0_p0, x_0_p1);
VF x_1_sum = hn::Add(x_1_p0, x_1_p1);
VF x_2_sum = hn::Add(x_2_p0, x_2_p1);
VF x_3_sum = hn::Add(x_3_p0, x_3_p1);
x_sum = Reduce4(df, x_0_sum, x_1_sum, x_2_sum, x_3_sum,
[](auto a, auto b) { return hn::Add(a, b); });
}
old_d_vf = hn::Add(scale, x_sum);
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df4, 0.0f));
const VF zero = hn::Zero(df);
const VF4 zero4 = hn::Zero(df4);
const VF4 one_over_d =
hn::MaskedDivOr(zero4, non_zero_mask, hn::Set(df4, 1.0f), old_d_vf);
HWY_ALIGN float tmp_one_over_d[4];
hn::Store(one_over_d, df4, tmp_one_over_d);
hn::Store(old_d_vf, df4, old_d);
scale = hn::Mul(scale, one_over_d);
hn::Store(scale, df4, scales);
if (hn::ExtractLane(old_d_vf, 0) > 0.0f) {
const VF one_over_d_0 = hn::Set(df, tmp_one_over_d[0]);
x_0_p0 = hn::Mul(x_0_p0, one_over_d_0);
x_0_p1 = hn::Mul(x_0_p1, one_over_d_0);
} else {
x_0_p0 = zero;
x_0_p1 = zero;
}
if constexpr (kNumQueries >= 2) {
if (hn::ExtractLane(old_d_vf, 1) > 0.0f) {
const VF one_over_d_1 = hn::Set(df, tmp_one_over_d[1]);
x_1_p0 = hn::Mul(x_1_p0, one_over_d_1);
x_1_p1 = hn::Mul(x_1_p1, one_over_d_1);
} else {
x_1_p0 = zero;
x_1_p1 = zero;
}
}
if constexpr (kNumQueries >= 3) {
if (hn::ExtractLane(old_d_vf, 2) > 0.0f) {
const VF one_over_d_2 = hn::Set(df, tmp_one_over_d[2]);
x_2_p0 = hn::Mul(x_2_p0, one_over_d_2);
x_2_p1 = hn::Mul(x_2_p1, one_over_d_2);
} else {
x_2_p0 = zero;
x_2_p1 = zero;
}
}
if constexpr (kNumQueries >= 4) {
if (hn::ExtractLane(old_d_vf, 3) > 0.0f) {
const VF one_over_d_3 = hn::Set(df, tmp_one_over_d[3]);
x_3_p0 = hn::Mul(x_3_p0, one_over_d_3);
x_3_p1 = hn::Mul(x_3_p1, one_over_d_3);
} else {
x_3_p0 = zero;
x_3_p1 = zero;
}
}
}
// Implements flash attention for a strip of 4 query vectors.
// It iterates through timesteps in K from `start_pos` up to `max_last_pos`.
// Timesteps up to `min_last_pos` (*) are processed in tiles of shape 4 Q rows
// by NF timesteps in K for efficiency while timesteps between `min_last_pos +
// 1` and `max_last_pos` are processed one-by-one to handle differing `last_pos`
// values within the strip.
// (*) Actually, it only iterates through
// `min_last_pos - (min_last_pos + 1 - start_pos) % NF` in tiles, as the tiled
// computation can only process an integer number of tiles.
//
// @param q The query matrix [batch_size * q_heads, qkv_dim] in BF16 format.
// @param q_offsets Offsets from `q.Row(0)` to the start of the 4 query
// vectors to be processed in this tile.
// @param k Key matrix [seq_len, qkv_dim] from KV cache.
// @param start_pos The first token position in the KV cache to attend to.
// @param last_pos An array of 4 indices giving the last token position
// (inclusive) that each of the 4 queries may attend to.
// @param min_last_pos The minimum value in `last_pos`. Timesteps up to this
// position can be processed efficiently in batches.
// @param max_last_pos The maximum value in `last_pos`. Timesteps between
// `min_last_pos + 1` and this position are processed individually to
// respect each query's `last_pos` limit.
// @param v Value matrix [seq_len, qkv_dim] from KV cache.
// @param layer_idx The index of the current transformer layer.
// @param activations Attention configurations and buffers.
// @param att_out Output buffer for attention results.
// @param out_offsets Offsets from `att_out.Row(0)` to store the 4 output
// vectors.
// @param ctx Threading context.
// @param worker Worker thread index.
Tile4FlashState TileFlashAttention4(
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
const MatPtrT<KV_t>& k, const size_t start_pos, const MatPtrT<KV_t>& k, const size_t start_pos,
const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos, const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx, const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
const LayerWeightsPtrs& layer, const AttentionActivations& activations, const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets, const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx,
ThreadingContext& ctx, const size_t worker) { const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention4); GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention4);
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
const DF df; const DF df;
@ -440,14 +638,7 @@ void TileFlashAttention4(
hwy::ZeroBytes(att_out.Row(0) + out_offsets[i], hwy::ZeroBytes(att_out.Row(0) + out_offsets[i],
v.Cols() * sizeof(att_out.Row(0)[0])); v.Cols() * sizeof(att_out.Row(0)[0]));
} }
float old_m0 = -std::numeric_limits<float>::max() / 2.0f; Tile4FlashState state;
float old_m1 = -std::numeric_limits<float>::max() / 2.0f;
float old_m2 = -std::numeric_limits<float>::max() / 2.0f;
float old_m3 = -std::numeric_limits<float>::max() / 2.0f;
float old_d0 = 0.0f;
float old_d1 = 0.0f;
float old_d2 = 0.0f;
float old_d3 = 0.0f;
size_t position = start_pos; size_t position = start_pos;
while (position + kHTileSize - 1 <= min_last_pos) { while (position + kHTileSize - 1 <= min_last_pos) {
int32_t k_offsets[kMaxNF]; int32_t k_offsets[kMaxNF];
@ -467,46 +658,62 @@ void TileFlashAttention4(
x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap))); x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap)));
x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap))); x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap)));
} }
scales[0] = SingleFlashAttentionRowVector(df, x0, old_m0, old_d0); scales[0] = SingleFlashAttentionRowVector(df, x0, state.row_states[0].max,
scales[1] = SingleFlashAttentionRowVector(df, x1, old_m1, old_d1); state.row_states[0].d);
scales[2] = SingleFlashAttentionRowVector(df, x2, old_m2, old_d2); scales[1] = SingleFlashAttentionRowVector(df, x1, state.row_states[1].max,
scales[3] = SingleFlashAttentionRowVector(df, x3, old_m3, old_d3); state.row_states[1].d);
scales[2] = SingleFlashAttentionRowVector(df, x2, state.row_states[2].max,
state.row_states[2].d);
scales[3] = SingleFlashAttentionRowVector(df, x3, state.row_states[3].max,
state.row_states[3].d);
MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0), MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0),
out_offsets, v.Cols()); out_offsets, v.Cols());
position += kHTileSize; position += kHTileSize;
} }
const hn::ScalableTag<BF16> dbf;
const size_t qkv_dim = k.Cols();
while (position <= max_last_pos) { while (position <= max_last_pos) {
size_t k_pos = activations.div_seq_len.Remainder(position); size_t k_pos = activations.div_seq_len.Remainder(position);
if (position <= last_pos[0]) { if (position <= last_pos[0]) {
// Past the last position, x0 doesn't count. // Past the last position, x0 doesn't count.
float x0 = Dot(q.Row(0) + q_offsets[0], k.Row(k_pos), k.Cols()); float x0 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[0], qkv_dim), 0,
SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x0, activations.config.att_cap,
state.row_states[0].max, state.row_states[0].d,
v.Row(k_pos), v.Cols(), v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[0]); att_out.Row(0) + out_offsets[0]);
} }
if (position <= last_pos[1]) { if (position <= last_pos[1]) {
// Past the last position, x1 doesn't count. // Past the last position, x1 doesn't count.
float x1 = Dot(q.Row(0) + q_offsets[1], k.Row(k_pos), k.Cols()); float x1 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[1], qkv_dim), 0,
SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x1, activations.config.att_cap,
state.row_states[1].max, state.row_states[1].d,
v.Row(k_pos), v.Cols(), v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[1]); att_out.Row(0) + out_offsets[1]);
} }
if (position <= last_pos[2]) { if (position <= last_pos[2]) {
// Past the last position, x2 doesn't count. // Past the last position, x2 doesn't count.
float x2 = Dot(q.Row(0) + q_offsets[2], k.Row(k_pos), k.Cols()); float x2 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[2], qkv_dim), 0,
SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x2, activations.config.att_cap,
state.row_states[2].max, state.row_states[2].d,
v.Row(k_pos), v.Cols(), v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[2]); att_out.Row(0) + out_offsets[2]);
} }
if (position <= last_pos[3]) { if (position <= last_pos[3]) {
// Past the last position, x3 doesn't count. // Past the last position, x3 doesn't count.
float x3 = Dot(q.Row(0) + q_offsets[3], k.Row(k_pos), k.Cols()); float x3 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[3], qkv_dim), 0,
SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x3, activations.config.att_cap,
state.row_states[3].max, state.row_states[3].d,
v.Row(k_pos), v.Cols(), v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[3]); att_out.Row(0) + out_offsets[3]);
} }
++position; ++position;
} }
return state;
} }
// Rounds n to a number that can be used as the number of Q rows in a tile // Rounds n to a number that can be used as the number of Q rows in a tile
@ -589,14 +796,25 @@ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,
// grouped together so that mode 1 or 2 can be used, and choosing which of the // grouped together so that mode 1 or 2 can be used, and choosing which of the
// 3 modes to use for best efficiency. // 3 modes to use for best efficiency.
void FlashAttention(const size_t num_tokens, const size_t target_parallelism, void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
const size_t layer_idx, const LayerWeightsPtrs& layer, const size_t layer_idx, const MatPtr& query_norm_scale,
AttentionActivations& activations, QBatch& qbatch, AttentionActivationsPtrs& activations, QBatch& qbatch,
ThreadingContext& ctx) { ThreadingContext& ctx) {
GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive); GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive);
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx, RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q,
layer, activations, ctx); query_norm_scale, layer_idx, activations, ctx);
const hwy::Divisor div_qbatch(qbatch.Size()); const hwy::Divisor div_qbatch(qbatch.Size());
const LayerConfig& layer_config = layer.layer_config; // Compress q to q_bf.
ParallelFor(
Parallelism::kWithinCluster, activations.q.Rows(), ctx,
/*cluster_idx=*/0, Callers::kFlashAttention,
[&](size_t row, size_t worker) {
CompressPerThread tls;
const hn::ScalableTag<float> df;
CompressTraits<BF16>::Compress(
df, activations.q.Row(row), activations.q.Cols(), tls,
MakeSpan(activations.q_bf.Row(row), activations.q_bf.Cols()), 0);
});
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
const size_t qkv_dim = layer_config.qkv_dim; const size_t qkv_dim = layer_config.qkv_dim;
// A "head group" in the context of GQA refers to a collection of query // A "head group" in the context of GQA refers to a collection of query
@ -684,14 +902,14 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
size_t last = pos; size_t last = pos;
const size_t prefix_end = qbatch.PrefixEnd(qi); const size_t prefix_end = qbatch.PrefixEnd(qi);
if (prefix_end > 0 && prefix_end - 1 > last) { if (prefix_end > 0 && prefix_end - 1 > last) {
// last_pos in QDotK and WeightedSumV is inclusive. // last_pos in `TileFlashAttention` is inclusive.
last = prefix_end - 1; last = prefix_end - 1;
} }
last_pos[offset] = last; last_pos[offset] = last;
min_last_pos = HWY_MIN(min_last_pos, last); min_last_pos = HWY_MIN(min_last_pos, last);
max_last_pos = HWY_MAX(max_last_pos, last); max_last_pos = HWY_MAX(max_last_pos, last);
q_offsets[offset] = q_offsets[offset] = activations.q_bf.Row(tq_idx) + head * qkv_dim -
activations.q.Row(tq_idx) + head * qkv_dim - activations.q.Row(0); activations.q_bf.Row(0);
out_offsets[offset] = activations.att_out.Row(tq_idx) + head * qkv_dim - out_offsets[offset] = activations.att_out.Row(tq_idx) + head * qkv_dim -
activations.att_out.Row(0); activations.att_out.Row(0);
const size_t kv_index = head / kHeadGroups; const size_t kv_index = head / kHeadGroups;
@ -719,9 +937,9 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
// To avoid duplicating the code to setup K and V, the call to // To avoid duplicating the code to setup K and V, the call to
// TileFlashAttention is inside the loop over tasks, even though it // TileFlashAttention is inside the loop over tasks, even though it
// handles all rows in the task at once. // handles all rows in the task at once.
StridedView<float> qT = StridedView<BF16> qT =
StridedView<float>(activations.q_T.Row(0) + first_task, kVTileSize, StridedView<BF16>(activations.q_T.Row(0) + first_task, kVTileSize,
activations.q_T.Stride()); activations.q_T.Stride());
if (kVTileSize == kNF) { if (kVTileSize == kNF) {
// We can still use TileFlashAttention even if we didn't transpose Q // We can still use TileFlashAttention even if we didn't transpose Q
// above. The condition used for transposing Q above is more general // above. The condition used for transposing Q above is more general
@ -730,14 +948,14 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
// kNFx8HTileSize. In this case, qT is never used. Some tasks might // kNFx8HTileSize. In this case, qT is never used. Some tasks might
// use qT and some might not, which is why the more general condition // use qT and some might not, which is why the more general condition
// is used above to catch all cases where qT will be used. // is used above to catch all cases where qT will be used.
TileFlashAttention(activations.q, q_offsets, qT, k, TileFlashAttention(activations.q_bf, q_offsets, qT, k,
start_positions[offset], last_pos, min_last_pos, start_positions[offset], last_pos, min_last_pos,
max_last_pos, v, layer_idx, layer, activations, max_last_pos, v, layer_idx, activations,
activations.att_out, out_offsets, ctx, worker); activations.att_out, out_offsets, ctx, worker);
} else if (kVTileSize == 4) { } else if (kVTileSize == 4) {
TileFlashAttention4(activations.q, q_offsets, k, TileFlashAttention4(activations.q_bf, q_offsets, k,
start_positions[offset], last_pos, min_last_pos, start_positions[offset], last_pos, min_last_pos,
max_last_pos, v, layer_idx, layer, activations, max_last_pos, v, layer_idx, activations,
activations.att_out, out_offsets, ctx, worker); activations.att_out, out_offsets, ctx, worker);
} else { } else {
HWY_UNREACHABLE; HWY_UNREACHABLE;
@ -745,8 +963,8 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
break; break;
} else { } else {
SingleFlashAttention(start_positions[offset], last_pos[offset], SingleFlashAttention(start_positions[offset], last_pos[offset],
activations.q.Row(0) + q_offsets[offset], k, v, activations.q_bf.Row(0) + q_offsets[offset], k, v,
layer_idx, layer, activations, layer_idx, activations,
activations.att_out.Row(0) + out_offsets[offset], activations.att_out.Row(0) + out_offsets[offset],
ctx, worker); ctx, worker);
} }

View File

@ -20,36 +20,48 @@
#include <stddef.h> #include <stddef.h>
#include <cstdint>
#include "gemma/flash_structs.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "hwy/highway.h" #include "hwy/highway.h"
namespace gcpp { namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target. // Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \ #define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \ namespace NAMESPACE { \
void RMSNormAndPositionalEncoding(size_t num_tokens, const QBatch& qbatch, \ void RMSNormAndPositionalEncoding( \
MatPtrT<KV_t>& q, size_t layer_idx, \ size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
const LayerWeightsPtrs& layer, \ const MatPtr& query_norm_scale, size_t layer_idx, \
const AttentionActivations& activations, \ const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
ThreadingContext& ctx); \ \
\ void SingleFlashAttention(size_t start_pos, size_t last_pos, \
void SingleFlashAttention(size_t start_pos, size_t last_pos, \ const BF16* HWY_RESTRICT q, \
const float* HWY_RESTRICT q, \ const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \ size_t layer_idx, \
size_t layer_idx, const LayerWeightsPtrs& layer, \ const AttentionActivationsPtrs& activations, \
const AttentionActivations& activations, \ float* HWY_RESTRICT att_out, \
float* HWY_RESTRICT att_out, \ ThreadingContext& ctx, size_t worker); \
ThreadingContext& ctx, size_t worker); \ \
\ Tile4FlashState TileFlashAttention4( \
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \ const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \
size_t total_tasks, size_t target_parallelism); \ const MatPtrT<KV_t>& k, size_t start_pos, \
\ const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \
void FlashAttention(size_t num_tokens, size_t target_parallelism, \ size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx, \
size_t layer_idx, const LayerWeightsPtrs& layer, \ const AttentionActivationsPtrs& activations, \
AttentionActivations& activations, QBatch& qbatch, \ MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets, \
ThreadingContext& ctx); \ ThreadingContext& ctx, const size_t worker); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ \
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
size_t total_tasks, size_t target_parallelism); \
\
void FlashAttention(size_t num_tokens, size_t target_parallelism, \
size_t layer_idx, const MatPtr& query_norm_scale, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \
ThreadingContext& ctx); \
\
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE } // namespace NAMESPACE
// Function declarations for each SIMD target. Allows direct call from the // Function declarations for each SIMD target. Allows direct call from the

View File

@ -14,6 +14,8 @@
// limitations under the License. // limitations under the License.
#include <cstring> #include <cstring>
#include <iostream>
#include <limits>
#include <numeric> #include <numeric>
#include <vector> #include <vector>
@ -24,6 +26,7 @@
#include "gemma/kv_cache.h" #include "gemma/kv_cache.h"
#include "gemma/weights.h" #include "gemma/weights.h"
#include "ops/matmul.h" #include "ops/matmul.h"
#include "util/test_util.h"
#ifndef HWY_DISABLED_TARGETS #ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS
@ -109,11 +112,14 @@ void TestFlashAttention(size_t target_parallelism) {
const LayerConfig& layer_config = config.layer_configs[0]; const LayerConfig& layer_config = config.layer_configs[0];
const LayerWeightsPtrs layers(0, layer_config, tensor_info_registry); const LayerWeightsPtrs layers(0, layer_config, tensor_info_registry);
InferenceArgs inference_args; InferenceArgs inference_args;
inference_args.attention_impl = "flash";
RuntimeConfig runtime_config; RuntimeConfig runtime_config;
inference_args.CopyTo(runtime_config);
KVCache kv_cache(config, inference_args, ctx.allocator); KVCache kv_cache(config, inference_args, ctx.allocator);
MatMulEnv env(ctx); MatMulEnv env(ctx);
Activations activations(config, runtime_config.prefill_tbatch_size, Activations activations(runtime_config, config,
kv_cache.SeqLen(), env.ctx, env.row_ptrs); runtime_config.prefill_tbatch_size, kv_cache.SeqLen(),
env.ctx, env.row_ptrs);
std::vector<int> tokens(kOuter); std::vector<int> tokens(kOuter);
std::iota(tokens.begin(), tokens.end(), 1); std::iota(tokens.begin(), tokens.end(), 1);
PromptTokens prompt(tokens); PromptTokens prompt(tokens);
@ -122,8 +128,10 @@ void TestFlashAttention(size_t target_parallelism) {
QBatch qbatch(/*start=*/0, /*max_size=*/kOuter, all_queries); QBatch qbatch(/*start=*/0, /*max_size=*/kOuter, all_queries);
const size_t batch_size = kOuter; const size_t batch_size = kOuter;
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs; std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
AttentionActivations attention(config, layer_config, batch_size, kOuter, AttentionActivations attention_storage(config, layer_config, batch_size,
ctx.allocator, row_ptrs); kOuter, runtime_config, ctx.allocator,
row_ptrs);
AttentionActivationsPtrs attention(config, kOuter, attention_storage);
const size_t qkv_dim = layer_config.qkv_dim; const size_t qkv_dim = layer_config.qkv_dim;
ASSERT_EQ(qkv_dim, kInner); ASSERT_EQ(qkv_dim, kInner);
const hwy::Divisor div_qbatch(qbatch.Size()); const hwy::Divisor div_qbatch(qbatch.Size());
@ -145,7 +153,8 @@ void TestFlashAttention(size_t target_parallelism) {
SetMat(h + layer_config.heads * 2, v); SetMat(h + layer_config.heads * 2, v);
} }
SetMat(1, attention.q); SetMat(1, attention.q);
DotSoftmaxWeightedSum(tokens.size(), 0, layers, attention, qbatch, ctx); DotSoftmaxWeightedSum(tokens.size(), 0, layers.query_norm_scale, attention,
qbatch, ctx);
// Copy the output to saved_att to allow for comparison. // Copy the output to saved_att to allow for comparison.
auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator); auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator);
SetMat(1, attention.q); SetMat(1, attention.q);
@ -158,8 +167,8 @@ void TestFlashAttention(size_t target_parallelism) {
total_tasks, target_parallelism); total_tasks, target_parallelism);
printf("FlashAttention: target_parallelism=%zu, kNF=%zu, kVTileSize=%zu\n", printf("FlashAttention: target_parallelism=%zu, kNF=%zu, kVTileSize=%zu\n",
target_parallelism, kNF, kVTileSize); target_parallelism, kNF, kVTileSize);
FlashAttention(tokens.size(), target_parallelism, 0, layers, attention, FlashAttention(tokens.size(), target_parallelism, 0, layers.query_norm_scale,
qbatch, ctx); attention, qbatch, ctx);
AssertClose(attention.att_out, *saved_att); AssertClose(attention.att_out, *saved_att);
ctx.profiler.PrintResults(); ctx.profiler.PrintResults();
} }

gemma/flash_structs.h (new file, 31 lines)
View File

@ -0,0 +1,31 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_
#include <stddef.h>
#include <limits>
namespace gcpp {
// State for computing softmax in a streaming ("online") manner,
// avoiding large intermediate values by subtracting the running maximum.
// For a sequence x_1, ..., x_n:
// m_i = max(m_{i-1}, x_i)
// d_i = d_{i-1} * exp(m_{i-1} - m_i) + exp(x_i - m_i)
// softmax_i = exp(x_i - m_i) / d_i
struct OnlineSoftmaxState {
// Maximum logit value encountered so far.
float max = -std::numeric_limits<float>::max() / 2.0f;
// Sum of exponentials scaled by exp(-max).
float d = 0.0f;
};
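For reference, a scalar sketch of the recurrence above using this struct; `OnlineSoftmaxUpdate` is a hypothetical helper and not part of this change:

#include <algorithm>
#include <cmath>

// One streaming step: folds x into the state and returns softmax_i for x_i.
static float OnlineSoftmaxUpdate(OnlineSoftmaxState& s, float x) {
  const float new_max = std::max(s.max, x);
  s.d = s.d * std::exp(s.max - new_max) + std::exp(x - new_max);
  s.max = new_max;
  return std::exp(x - s.max) / s.d;
}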
static constexpr size_t kVTileSize4 = 4;
struct Tile4FlashState {
OnlineSoftmaxState row_states[kVTileSize4];
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_

View File

@ -20,6 +20,7 @@
#include "gemma/activations.h" #include "gemma/activations.h"
#include "gemma/configs.h" #include "gemma/configs.h"
#include "gemma/tensor_stats.h"
#include "gemma/weights.h" #include "gemma/weights.h"
#include "ops/matmul.h" #include "ops/matmul.h"
#include "util/mat.h" #include "util/mat.h"
@ -70,7 +71,7 @@ template <class Mat>
void ActivationBatched( void ActivationBatched(
ActivationType activation, Mat& c1, ThreadingContext& ctx, ActivationType activation, Mat& c1, ThreadingContext& ctx,
size_t cluster_idx = 0, size_t cluster_idx = 0,
ParallelismStrategy parallelism = ParallelismStrategy::kFlat) { Parallelism parallelism = Parallelism::kFlat) {
using T = typename Mat::T; using T = typename Mat::T;
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
Callers::kActivationBatched, [&](uint64_t task, size_t worker) { Callers::kActivationBatched, [&](uint64_t task, size_t worker) {
@ -115,7 +116,7 @@ template <class Mat1, class Mat2>
HWY_NOINLINE void ActivationBatched( HWY_NOINLINE void ActivationBatched(
ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx, ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx,
size_t cluster_idx = 0, size_t cluster_idx = 0,
ParallelismStrategy parallelism = ParallelismStrategy::kFlat) { Parallelism parallelism = Parallelism::kFlat) {
HWY_DASSERT(c1.SameShape(*c2)); HWY_DASSERT(c1.SameShape(*c2));
if (c2 && c2->HasPtr()) { if (c2 && c2->HasPtr()) {
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
@ -158,6 +159,9 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
HWY_DASSERT(!layer_config.ff_biases); // Only used in Vit. HWY_DASSERT(!layer_config.ff_biases); // Only used in Vit.
activations.s_ffw_in.Notify(layer.layer_idx, activations.pre_ffw_rms_out,
env.ctx);
#if GEMMA_FUSED_FFN #if GEMMA_FUSED_FFN
const auto fused = [&](RowPtrsBF C1, IndexRange range_r, IndexRange range_c, const auto fused = [&](RowPtrsBF C1, IndexRange range_r, IndexRange range_c,
StridedViewBF C2, size_t worker) { StridedViewBF C2, size_t worker) {
@ -179,8 +183,31 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
env.ctx); env.ctx);
#endif #endif
activations.s_ffw_hidden.Notify(layer.layer_idx, activations.C1, env.ctx);
// Hidden layer -> output layer. // Hidden layer -> output layer.
CallMatMul(activations.C1, layer.linear_w, nullptr, env, activations.ffw_out); CallMatMul(activations.C1, layer.linear_w, nullptr, env, activations.ffw_out);
activations.s_ffw_out.Notify(layer.layer_idx, activations.ffw_out, env.ctx);
}
// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
// head_dim (`qkv_dim`) into output (`layer_out`).
static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
AttentionActivationsPtrs& activations,
MatMulEnv& env) {
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
const LayerConfig& layer_config = layer.layer_config;
(void)layer_config; // For HWY_DASSERT
// att_weights and att_out are concatenated heads, each of length
// layer_config.qkv_dim. Thus the [num_interleaved,
// layer_config.model_dim] matmul output is the sum over heads. Compare
// gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD',
// encoded)
HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 &&
layer_config.qkv_dim != 0);
CallMatMul(activations.att_out, layer.att_weights, /*add=*/nullptr, env,
activations.att_sums);
} }
// NOLINTNEXTLINE(google-readability-namespace-comments) // NOLINTNEXTLINE(google-readability-namespace-comments)

View File

@ -18,12 +18,16 @@
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include <optional>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS #include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "util/zones.h"
#ifndef HWY_DISABLED_TARGETS #ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS
#include "gemma/tensor_stats.h"
#include "util/zones.h"
// Compiles this file for multiple architectures via "foreach_target.h", to // Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'. // which we pass the filename via macro 'argument'.
// clang-format off // clang-format off
@ -35,7 +39,7 @@
// After highway.h // After highway.h
#include "gemma/attention.h" // includes highway.h #include "gemma/attention.h" // includes highway.h
#include "gemma/gemma-inl.h" #include "gemma/gemma-inl.h"
#include "gemma/vit.h" // includes highway.h #include "gemma/vit.h" // includes highway.h
#ifndef GEMMA_CC_ONCE #ifndef GEMMA_CC_ONCE
#define GEMMA_CC_ONCE #define GEMMA_CC_ONCE
@ -73,10 +77,12 @@ namespace HWY_NAMESPACE {
void Attention(LayerAttentionType type, const size_t num_tokens, void Attention(LayerAttentionType type, const size_t num_tokens,
const size_t layer_idx, const LayerWeightsPtrs& layer, const size_t layer_idx, const LayerWeightsPtrs& layer,
Activations& activations, QBatch& qbatch, MatMulEnv& env) { Activations& activations, QBatch& qbatch, MatMulEnv& env) {
if (type == LayerAttentionType::kGemma) { if (type == LayerAttentionType::kGemma) {
// TODO: remove flag to enable FlashAttention. // TODO: remove flag to enable FlashAttention.
GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch, GemmaAttention(
env, HWY_NATIVE_DOT_BF16 ? kAttentionUseOld : 0); num_tokens, layer_idx, layer, activations.attention, qbatch, env,
AttentionImplToFlags(activations.attention_impl, HWY_NATIVE_DOT_BF16));
} }
} }
@ -355,6 +361,10 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
(void)runtime_config.StreamToken(qbatch.QueryIdx(qi), pos_in_prompt, (void)runtime_config.StreamToken(qbatch.QueryIdx(qi), pos_in_prompt,
token, 0.0f); token, 0.0f);
qbatch.MutablePos(qi) = pos_in_prompt; qbatch.MutablePos(qi) = pos_in_prompt;
} else {
// Prevents the KV cache entry for eos_id from being written at the
// position of the last prefilled token.
qbatch.MutablePos(qi) = qbatch.Prompt(qi).size();
} }
qbatch.PrevToken(qi) = token; qbatch.PrevToken(qi) = token;
@ -426,7 +436,7 @@ static void SampleAndStream(const ModelConfig& config,
timing_info.NotifyGenerated(non_eos.Count()); timing_info.NotifyGenerated(non_eos.Count());
ParallelFor( ParallelFor(
ParallelismStrategy::kFlat, qbatch.Size(), env.ctx, Parallelism::kFlat, qbatch.Size(), env.ctx,
/*cluster_idx=*/0, Callers::kSampleAndStream, /*cluster_idx=*/0, Callers::kSampleAndStream,
[&](size_t qi, size_t worker) { [&](size_t qi, size_t worker) {
if (!non_eos.Get(qi)) return; if (!non_eos.Get(qi)) return;
@ -484,12 +494,11 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config,
}; };
} }
// Decode: generates one continuation token for each query in `qbatch`. static size_t PrefillTBatchOrQBatch(const ModelConfig& config,
static void GenerateT(const ModelConfig& config, const RuntimeConfig& runtime_config,
const RuntimeConfig& runtime_config, const WeightsPtrs& weights,
const AesCtrEngine& engine, const WeightsPtrs& weights, Activations& activations, QBatch& qbatch,
Activations& activations, QBatch& qbatch, MatMulEnv& env, MatMulEnv& env, TimingInfo& timing_info) {
TimingInfo& timing_info) {
size_t max_prompt_size = 0; size_t max_prompt_size = 0;
bool all_prefix_end_are_zero = true; bool all_prefix_end_are_zero = true;
size_t total_prefill_tokens = 0; // only for throughput stats. size_t total_prefill_tokens = 0; // only for throughput stats.
@ -511,14 +520,14 @@ static void GenerateT(const ModelConfig& config,
// We use a single divisor, so all sequence lengths must be the same. // We use a single divisor, so all sequence lengths must be the same.
HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len); HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
} }
if (max_prompt_size >= seq_len) { if (max_prompt_size > seq_len) {
HWY_ABORT("max_prompt_size = %zu, increase --seq_len to at least that.", HWY_ABORT(
max_prompt_size); "max_prompt_size = %zu, seq_len = %zu, increase --seq_len to at least "
"that.",
max_prompt_size, seq_len);
} }
HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len); HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);
// Lacks a constructor to bulk-set, hence initialized by Prefill* which have
// qi loops anyway.
hwy::BitSet4096<> non_eos; // indexed by qi hwy::BitSet4096<> non_eos; // indexed by qi
timing_info.prefill_start = hwy::platform::Now(); timing_info.prefill_start = hwy::platform::Now();
@ -536,18 +545,6 @@ static void GenerateT(const ModelConfig& config,
timing_info.NotifyPrefill(total_prefill_tokens); timing_info.NotifyPrefill(total_prefill_tokens);
// queries_pos have been incremented by Prefill. // queries_pos have been incremented by Prefill.
// Stream the last prompt token from each query, fill activations.gen_tokens.
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct.
// In autoregressive mode, we have not prefilled the last token, so do
// not advance.
const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi));
StreamAndUpdateEOS(qi, pos, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f,
config, runtime_config, qbatch, update_pos, non_eos);
}
size_t max_gen_steps = runtime_config.max_generated_tokens; size_t max_gen_steps = runtime_config.max_generated_tokens;
if (max_prompt_size + max_gen_steps > seq_len) { if (max_prompt_size + max_gen_steps > seq_len) {
HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.", HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.",
@ -555,6 +552,55 @@ static void GenerateT(const ModelConfig& config,
max_gen_steps = seq_len - max_prompt_size; max_gen_steps = seq_len - max_prompt_size;
} }
return max_gen_steps;
}
static void StreamAndUpdateEOSAfterPrefill(const ModelConfig& config,
const RuntimeConfig& runtime_config,
QBatch& qbatch,
hwy::BitSet4096<>& non_eos,
size_t qi) {
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct.
// In autoregressive mode, we have not prefilled the last token, so do
// not advance.
const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi));
StreamAndUpdateEOS(qi, pos, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f,
config, runtime_config, qbatch, update_pos, non_eos);
}
void SetWeightStats(const LayerWeightsPtrs& layer, Activations& a,
ThreadingContext& ctx) {
const size_t layer_idx = layer.layer_idx;
a.s_w_gating_einsum_w1.Notify(layer_idx, layer.gating_einsum_w1, ctx,
kTensorStatsIsWeight);
a.s_w_gating_einsum_w2.Notify(layer_idx, layer.gating_einsum_w2, ctx,
kTensorStatsIsWeight);
a.s_w_linear_w.Notify(layer_idx, layer.linear_w, ctx, kTensorStatsIsWeight);
}
// Decode: generates one continuation token for each query in `qbatch`.
static void GenerateT(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const AesCtrEngine& engine, const WeightsPtrs& weights,
Activations& activations, QBatch& qbatch, MatMulEnv& env,
TimingInfo& timing_info) {
for (const LayerWeightsPtrs& layer : weights.c_layers) {
SetWeightStats(layer, activations, env.ctx);
}
const size_t max_gen_steps = PrefillTBatchOrQBatch(
config, runtime_config, weights, activations, qbatch, env, timing_info);
hwy::BitSet4096<> non_eos; // indexed by qi
// Stream the last prompt token from each query, fill activations.gen_tokens.
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
non_eos.Set(qi);
StreamAndUpdateEOSAfterPrefill(config, runtime_config, qbatch, non_eos, qi);
}
const SampleFunc sample_token = const SampleFunc sample_token =
ChooseSampleFunc(runtime_config, engine, env.ctx); ChooseSampleFunc(runtime_config, engine, env.ctx);
@ -567,14 +613,66 @@ static void GenerateT(const ModelConfig& config,
timing_info.NotifyGenerateDone(); timing_info.NotifyGenerateDone();
} }
// Same as GenerateT, but uses ContinuousQBatch.
static void GenerateTWithContinuousBatching(
const ModelConfig& config, const RuntimeConfig& runtime_config,
const AesCtrEngine& engine, const WeightsPtrs& weights,
Activations& activations, AllQueries& all_queries, MatMulEnv& env,
TimingInfo& timing_info) {
const size_t qbatch_size = runtime_config.decode_qbatch_size;
QBatch qbatch(0, qbatch_size, all_queries);
ContinuousQBatch prefill_batch(qbatch_size, all_queries);
hwy::BitSet4096<> non_eos;
const SampleFunc sample_token =
ChooseSampleFunc(runtime_config, engine, env.ctx);
size_t query_inserted = 0;
while (non_eos.Any() || query_inserted < all_queries.NumQueries()) {
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
// Skip slot qi if it is still processing a query.
if (non_eos.Get(qi)) continue;
// Return the kv_cache from slot qi of the qbatch to the
// available_kv_caches_ in the prefill_batch.
prefill_batch.MaybeReleaseKV(qbatch.Single(qi));
// Prefill if there are no prefilled queries available to insert.
if (prefill_batch.ShouldPrefill()) {
prefill_batch.SetupNextBatchForPrefill();
PrefillTBatchOrQBatch(config, runtime_config, weights, activations,
prefill_batch, env, timing_info);
activations.SetBatchSize(qbatch.Size());
}
// Get the next query to insert into the generate batch.
std::optional<size_t> qi_to_insert = prefill_batch.GetNextToInsert();
if (qi_to_insert) {
qbatch.Insert(qi_to_insert.value(), qi);
query_inserted++;
non_eos.Set(qi);
StreamAndUpdateEOSAfterPrefill(config, runtime_config, qbatch, non_eos,
qi);
}
}
Transformer(config, runtime_config, weights, activations, qbatch, env);
SampleAndStream(config, runtime_config, weights, sample_token, activations,
qbatch, env, non_eos, timing_info);
}
timing_info.NotifyGenerateDone();
}
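Continuous batching is opt-in: callers set the new `use_continuous_batching` flag on `RuntimeConfig` (added in gemma_args.h below), which makes `GenerateBatchT` call `GenerateTWithContinuousBatching`. A minimal caller-side sketch, not part of this commit, with the other `RuntimeConfig` fields left at their defaults:
RuntimeConfig runtime_config = {.verbosity = 1};
runtime_config.use_continuous_batching = true;
// Pass runtime_config to the usual batch-generation entry point.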
void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end, void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
const ModelConfig& config, const ModelConfig& config,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const AesCtrEngine& engine, const WeightsPtrs& weights, const AesCtrEngine& engine, const WeightsPtrs& weights,
KVCache& kv_cache, MatMulEnv& env, KVCache& kv_cache, MatMulEnv& env,
TimingInfo& timing_info) { TimingInfo& timing_info) {
Activations activations(config, runtime_config.prefill_tbatch_size, Activations activations(runtime_config, config,
kv_cache.SeqLen(), env.ctx, env.row_ptrs); runtime_config.prefill_tbatch_size, kv_cache.SeqLen(),
env.ctx, env.row_ptrs);
AllQueries all_queries(prompt, pos, prefix_end, AllQueries all_queries(prompt, pos, prefix_end,
hwy::Span<KVCache>(&kv_cache, 1)); hwy::Span<KVCache>(&kv_cache, 1));
@ -592,16 +690,21 @@ void GenerateBatchT(const ModelConfig& config,
TimingInfo& timing_info) { TimingInfo& timing_info) {
const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size, const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
runtime_config.prefill_tbatch_size); runtime_config.prefill_tbatch_size);
Activations activations(config, max_batch_size, Activations activations(runtime_config, config, max_batch_size,
all_queries[0].kv_cache.SeqLen(), env.ctx, all_queries[0].kv_cache.SeqLen(), env.ctx,
env.row_ptrs); env.row_ptrs);
for (size_t start = 0; start < all_queries.NumQueries(); if (runtime_config.use_continuous_batching) {
start += runtime_config.decode_qbatch_size) { GenerateTWithContinuousBatching(config, runtime_config, engine, weights,
QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries); activations, all_queries, env, timing_info);
// Generate a batch of one token for each of `qbatch.Size()` queries. } else {
GenerateT(config, runtime_config, engine, weights, activations, qbatch, env, for (size_t start = 0; start < all_queries.NumQueries();
timing_info); start += runtime_config.decode_qbatch_size) {
QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries);
// Generate a batch of one token for each of `qbatch.Size()` queries.
GenerateT(config, runtime_config, engine, weights, activations, qbatch,
env, timing_info);
}
} }
} }
@ -617,8 +720,8 @@ void GenerateImageTokensT(const ModelConfig& config,
const size_t num_tokens = vit_config.max_seq_len; const size_t num_tokens = vit_config.max_seq_len;
prefill_runtime_config.prefill_tbatch_size = prefill_runtime_config.prefill_tbatch_size =
num_tokens / (vit_config.pool_dim * vit_config.pool_dim); num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
Activations prefill_activations(vit_config, num_tokens, num_tokens, env.ctx, Activations prefill_activations(runtime_config, vit_config, num_tokens,
env.row_ptrs); num_tokens, env.ctx, env.row_ptrs);
// Weights are for the full PaliGemma model, not just the ViT part. // Weights are for the full PaliGemma model, not just the ViT part.
PrefillVit(config, weights, prefill_runtime_config, image, image_tokens, PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
prefill_activations, env); prefill_activations, env);
@ -635,17 +738,16 @@ HWY_EXPORT(GenerateSingleT);
HWY_EXPORT(GenerateBatchT); HWY_EXPORT(GenerateBatchT);
HWY_EXPORT(GenerateImageTokensT); HWY_EXPORT(GenerateImageTokensT);
Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference, Gemma::Gemma(const GemmaArgs& args, ThreadingContext& ctx)
ThreadingContext& ctx) : reader_(args.loader.weights),
: reader_(loader.weights), model_(reader_, args.loader.tokenizer, args.loader.wrapping),
model_(reader_, loader.tokenizer, loader.wrapping),
weights_(model_.Config()), weights_(model_.Config()),
chat_template_(model_.Tokenizer(), model_.Config().model), chat_template_(model_.Tokenizer(), model_.Config().model),
inference_(inference), inference_(args.inference),
aes_ctr_engine_(inference.deterministic) { aes_ctr_engine_(args.inference.deterministic) {
// Negligible CPU time in the ctor body (except ReadFromBlobs). // Negligible CPU time in the ctor body (except ReadFromBlobs).
weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader, inference, weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, args.loader,
mat_owners_, ctx); args.inference, mat_owners_, ctx);
// Read everything into memory, or `weights_.mapped_` keeps the mapping alive. // Read everything into memory, or `weights_.mapped_` keeps the mapping alive.
reader_.CloseFile(); reader_.CloseFile();
} }
@ -698,5 +800,64 @@ void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
} }
ContinuousQBatch::ContinuousQBatch(size_t max_size, AllQueries& queries)
: QBatch(0, max_size, queries) {
for (size_t i = start_; i < queries_.NumQueries(); ++i) {
if (!queries_[i].kv_cache.IsEmpty()) {
// Move the kv_cache into available_kv_caches_ instead; leaving it in
// queries_ would be confusing. This simplifies the logic of kv_cache
// management.
available_kv_caches_.push_back(queries_[i].kv_cache);
queries_[i].kv_cache = KVCachePtr();
}
}
}
bool ContinuousQBatch::ShouldPrefill() const {
const bool no_available_to_insert = next_to_insert_ == next_to_prefill_;
const bool more_queries_to_prefill = next_to_prefill_ < queries_.NumQueries();
return no_available_to_insert && more_queries_to_prefill;
}
void ContinuousQBatch::SetupNextBatchForPrefill() {
start_ = next_to_prefill_;
size_ = HWY_MIN(max_size_, queries_.NumQueries() - start_);
HWY_DASSERT(size_ != 0);
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
query_idx_.clear();
query_idx_.reserve(size_);
for (size_t i = 0; i < size_; ++i) {
const size_t next_query_idx = start_ + i;
query_idx_.push_back(next_query_idx);
HWY_ASSERT(queries_[next_query_idx].kv_cache.IsEmpty());
queries_[next_query_idx].kv_cache = available_kv_caches_.back();
available_kv_caches_.pop_back();
}
next_to_prefill_ += size_;
}
std::optional<size_t> ContinuousQBatch::GetNextToInsert() {
if (next_to_insert_ == next_to_prefill_) {
return std::nullopt;
}
next_to_insert_++;
return next_to_insert_ - 1;
}
void ContinuousQBatch::MaybeReleaseKV(const QBatch& from) {
const int query_to_collect = from.QueryIdx(0);
// Only collect if the query to collect differs from the next query to
// insert; they are equal at the beginning of each Generate call.
if (query_to_collect != next_to_insert_) {
// Only clear the KV cache if there are more queries to insert; otherwise
// we would crash because Transformer still accesses that KV cache.
if (next_to_insert_ < queries_.NumQueries()) {
available_kv_caches_.push_back(from.KV(0));
ZeroInit(from.KV(0).kv_cache);
from.KV(0) = KVCachePtr();
}
}
}
} // namespace gcpp } // namespace gcpp
#endif // HWY_ONCE #endif // HWY_ONCE

View File

@ -18,6 +18,7 @@
#include <stdio.h> #include <stdio.h>
#include <optional>
#include <vector> #include <vector>
// IWYU pragma: begin_exports // IWYU pragma: begin_exports
@ -26,6 +27,7 @@
#include "gemma/gemma_args.h" #include "gemma/gemma_args.h"
#include "gemma/kv_cache.h" #include "gemma/kv_cache.h"
#include "gemma/model_store.h" #include "gemma/model_store.h"
#include "gemma/query.h"
#include "gemma/weights.h" #include "gemma/weights.h"
#include "io/blob_store.h" #include "io/blob_store.h"
#include "io/io.h" // Path #include "io/io.h" // Path
@ -38,132 +40,28 @@
namespace gcpp { namespace gcpp {
struct PerQuery { // Used for continuous batching.
PromptTokens prompt; class ContinuousQBatch : public QBatch {
// Position in the KV cache: initially zero for the first turn, or when
// multi-turn is NOT desired. Incremented by prefill and `StreamAndUpdateEOS`.
size_t mutable_pos;
// Allows computing the last prefill token as `mutable_pos - initial_pos`,
// which might differ from `prompt.size() - 1` for prefix-LM.
size_t initial_pos;
// Zero for causal attention, or the end of the prefix for prefix-LM style
// attention in Paligemma.
size_t prefix_end;
KVCache& kv_cache;
// Previous token generated for this query, or the last prompt token. Will be
// fed into the next Transformer() call.
int prev_token = 0;
};
// Array of `PerQuery`. Referenced by `QBatch` and passed to `GenerateBatch`.
struct AllQueries {
AllQueries() = default;
// For `GenerateSingleT`: same prompt/pos, replicated for each KV cache.
AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
const hwy::Span<KVCache>& kv_caches) {
per_query_.reserve(kv_caches.size());
for (size_t i = 0; i < kv_caches.size(); ++i) {
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
per_query_.push_back(PerQuery{
.prompt = prompt,
.mutable_pos = pos,
.initial_pos = pos,
.prefix_end = prefix_end,
.kv_cache = kv_caches[i],
});
}
}
// Batch of queries with initial position set to zero. Causal attention
// is requested via empty or all-zero `prefix_end`.
AllQueries(
const hwy::Span<const PromptTokens>& prompts,
const hwy::Span<KVCache>& kv_caches,
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
HWY_ASSERT(prompts.size() == kv_caches.size());
HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
per_query_.reserve(kv_caches.size());
for (size_t i = 0; i < kv_caches.size(); ++i) {
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
per_query_.push_back(PerQuery{
.prompt = prompts[i],
.mutable_pos = 0,
.initial_pos = 0,
.prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i],
.kv_cache = kv_caches[i],
});
}
}
void Reserve(size_t size) { per_query_.reserve(size); }
void Append(const PerQuery& query) { per_query_.push_back(query); }
size_t NumQueries() const { return per_query_.size(); }
PerQuery& operator[](size_t query_idx) {
HWY_DASSERT(query_idx < NumQueries());
return per_query_[query_idx];
}
const PerQuery& operator[](size_t query_idx) const {
HWY_DASSERT(query_idx < NumQueries());
return per_query_[query_idx];
}
private:
std::vector<PerQuery> per_query_;
};
// View into AllQueries: either a batch of queries, or a single query for use
// in PrefillTBatch or GenerateSingleT. Cheap to create because it holds a
// reference to AllQueries.
class QBatch {
public: public:
QBatch(size_t start, size_t max_size, AllQueries& queries) ContinuousQBatch(size_t max_size, AllQueries& queries);
: start_(start),
max_size_(max_size),
queries_(queries),
size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) {
HWY_ASSERT(max_size_ <= kMaxBatchSize);
HWY_DASSERT(size_ != 0);
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
}
// Returns a single-query view starting at `qi` relative to this batch. // Whether we should prefill the next batch, i.e. next_to_insert_ ==
QBatch Single(size_t qi) const { return QBatch(start_ + qi, 1, queries_); } // next_to_prefill_.
bool ShouldPrefill() const;
// How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`. // Setup the query_idx_ to point to the next group of queries to prefill.
size_t Size() const { return size_; } void SetupNextBatchForPrefill();
// Returns index for use with `AllQueries` and `BatchStreamToken`. // Get the next query to insert into the generate batch.
size_t QueryIdx(size_t qi) const { std::optional<size_t> GetNextToInsert();
HWY_DASSERT(qi < size_);
return start_ + qi;
}
// Accessor functions to bridge the previous SoA and current AoS layout. // Collect the kv_cache from QBatch to available_kv_caches_.
const PromptTokens& Prompt(size_t qi) const { void MaybeReleaseKV(const QBatch& from);
return queries_[QueryIdx(qi)].prompt;
}
size_t Pos(size_t qi) const { return queries_[QueryIdx(qi)].mutable_pos; }
size_t& MutablePos(size_t qi) { return queries_[QueryIdx(qi)].mutable_pos; }
size_t InitialPos(size_t qi) const {
return queries_[QueryIdx(qi)].initial_pos;
}
size_t PrefixEnd(size_t qi) const {
return queries_[QueryIdx(qi)].prefix_end;
}
KVCache& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
private: public:
size_t start_; int next_to_prefill_ = 0;
size_t max_size_; int next_to_insert_ = 0;
AllQueries& queries_; std::vector<KVCachePtr> available_kv_caches_;
size_t size_;
}; };
struct TimingInfo { struct TimingInfo {
@ -232,11 +130,16 @@ struct TimingInfo {
// separate `ThreadingContext` and `MatMulEnv` for each concurrent `Generate`. // separate `ThreadingContext` and `MatMulEnv` for each concurrent `Generate`.
class Gemma { class Gemma {
public: public:
// Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`. // Reads weights/config/tokenizer from `BlobStore` at `args.loader.weights`.
// `ctx` is only used to read tensors and not stored. Calls to `Generate*` // `ctx` is only used to read tensors and not stored. Calls to `Generate*`
// may reference the same, or other `ThreadingContext` via `MatMulEnv`. // may reference the same, or other `ThreadingContext` via `MatMulEnv`.
Gemma(const GemmaArgs& args, ThreadingContext& ctx);
// Deprecated prior interface for backwards compatibility.
Gemma(const LoaderArgs& loader, const InferenceArgs& inference, Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
ThreadingContext& ctx); ThreadingContext& ctx)
: Gemma(GemmaArgs(loader, ThreadingArgs(), inference), ctx) {}
~Gemma(); ~Gemma();
const ModelConfig& Config() const { return model_.Config(); } const ModelConfig& Config() const { return model_.Config(); }

View File

@ -24,10 +24,12 @@
#include <functional> #include <functional>
#include <string> #include <string>
#include "io/io.h" // Path #include "gemma/configs.h"
#include "util/args.h" #include "io/io.h" // Path
#include "util/args.h" // IWYU pragma: export
#include "util/basics.h" // Tristate #include "util/basics.h" // Tristate
#include "util/mat.h" #include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span #include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h" // HWY_ABORT #include "hwy/base.h" // HWY_ABORT
#include "hwy/profiler.h" #include "hwy/profiler.h"
@ -35,7 +37,9 @@
namespace gcpp { namespace gcpp {
struct LoaderArgs : public ArgsBase<LoaderArgs> { struct LoaderArgs : public ArgsBase<LoaderArgs> {
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } LoaderArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
LoaderArgs(const std::string& tokenizer_path, LoaderArgs(const std::string& tokenizer_path,
const std::string& weights_path) { const std::string& weights_path) {
Init(); // Init sets to defaults, so assignments must come after Init(). Init(); // Init sets to defaults, so assignments must come after Init().
@ -139,6 +143,9 @@ struct RuntimeConfig {
int verbosity; // Controls verbosity of printed messages. int verbosity; // Controls verbosity of printed messages.
// Which attention implementation to use.
AttentionImpl attention_impl = AttentionImpl::kFlash;
// Functions operating on the generated tokens. // Functions operating on the generated tokens.
StreamFunc stream_token; StreamFunc stream_token;
BatchStreamFunc batch_stream_token; BatchStreamFunc batch_stream_token;
@ -159,10 +166,15 @@ struct RuntimeConfig {
// default decision is likely sufficient because it is based on whether // default decision is likely sufficient because it is based on whether
// threads are successfully pinned. // threads are successfully pinned.
mutable Tristate use_spinning = Tristate::kDefault; mutable Tristate use_spinning = Tristate::kDefault;
// Whether to use continuous batching.
bool use_continuous_batching = false;
}; };
struct InferenceArgs : public ArgsBase<InferenceArgs> { struct InferenceArgs : public ArgsBase<InferenceArgs> {
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } InferenceArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
InferenceArgs() { Init(); }; InferenceArgs() { Init(); };
bool IsInteractive() const { return prompt.empty() && prompt_file.Empty(); } bool IsInteractive() const { return prompt.empty() && prompt_file.Empty(); }
@ -187,6 +199,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
// For prompts longer than the Linux terminal's 4K line edit buffer. // For prompts longer than the Linux terminal's 4K line edit buffer.
Path prompt_file; Path prompt_file;
std::string eot_line; std::string eot_line;
std::string attention_impl;
template <class Visitor> template <class Visitor>
void ForEach(const Visitor& visitor) { void ForEach(const Visitor& visitor) {
@ -240,6 +253,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
"before the line where only the given string appears.\n Default = " "before the line where only the given string appears.\n Default = "
"When a newline is encountered, that signals the end of the turn.", "When a newline is encountered, that signals the end of the turn.",
2); 2);
visitor(attention_impl, "attention_impl", std::string("flash"),
"Attention implementation to use. See configs.cc for options.", 2);
} }
void CopyTo(RuntimeConfig& runtime_config) const { void CopyTo(RuntimeConfig& runtime_config) const {
@ -261,36 +276,39 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
runtime_config.temperature = temperature; runtime_config.temperature = temperature;
runtime_config.top_k = top_k; runtime_config.top_k = top_k;
runtime_config.attention_impl = GetAttentionImpl(attention_impl);
} }
}; };
struct ClientArgs : public ArgsBase<ClientArgs> { // Bundles all args required to construct a `GemmaEnv` or the equivalent.
ClientArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } struct GemmaArgs {
ClientArgs() { Init(); }; // For callers that do not parse command line args.
GemmaArgs(const LoaderArgs& loader,
const ThreadingArgs& threading = ThreadingArgs(),
const InferenceArgs& inference = InferenceArgs())
: loader(loader), threading(threading), inference(inference) {}
std::string host; GemmaArgs(int argc, char** argv, ConsumedArgs& consumed)
int port; : loader(argc, argv, consumed),
std::string api_key; threading(argc, argv, consumed),
std::string model; inference(argc, argv, consumed) {}
std::string prompt;
bool interactive;
template <class Visitor> void Help() {
void ForEach(const Visitor& visitor) { fprintf(stderr,
visitor(host, "host", std::string("localhost"), "To run with pre-2025 weights, specify --tokenizer and --weights.\n"
"Server host (default: localhost)"); "With the single-file weights format, specify just --weights.\n"
visitor(port, "port", 8080, "\n*Model Loading Arguments*\n");
"Server port (default: 8080)"); loader.Help();
visitor(api_key, "api_key", std::string(""), fprintf(stderr, "\n*Threading Arguments*\n");
"Use public API with key (changes host to " threading.Help();
"generativelanguage.googleapis.com:443)"); fprintf(stderr, "\n*Inference Arguments*\n");
visitor(model, "model", std::string("gemma3-4b"), inference.Help();
"Model name to use (default: gemma3-4b)"); fprintf(stderr, "\n");
visitor(prompt, "prompt", std::string("Hello! How are you?"),
"Prompt for generation (default: 'Hello! How are you?')");
visitor(interactive, "interactive", false,
"Start interactive chat mode (0 = no, 1 = yes)");
} }
LoaderArgs loader;
ThreadingArgs threading;
InferenceArgs inference;
}; };
} // namespace gcpp } // namespace gcpp
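For reference, a minimal sketch (not from this commit) of constructing a `Gemma` through the new `GemmaArgs` bundle without parsing command-line flags; the tokenizer/weights paths are placeholders:
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "util/threading_context.h"
int main() {
  gcpp::LoaderArgs loader("tokenizer.spm", "weights.sbs");  // placeholder paths
  gcpp::GemmaArgs args(loader);  // default ThreadingArgs and InferenceArgs
  gcpp::ThreadingContext ctx(args.threading);
  gcpp::Gemma gemma(args, ctx);
  (void)gemma.Config();
  return 0;
}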

74
gemma/gemma_args_test.cc Normal file
View File

@ -0,0 +1,74 @@
#include "gemma/gemma_args.h"
#include <stddef.h>
#include <string>
#include <vector>
#include "gtest/gtest.h"
namespace gcpp {
void FillPtrs(const std::vector<std::string>& args, std::vector<char*>& ptrs) {
ptrs.reserve(args.size());
for (const std::string& arg : args) {
ptrs.push_back(const_cast<char*>(arg.data()));
}
}
static void CheckAllConsumed(const std::vector<std::string>& args) {
std::vector<char*> ptrs;
FillPtrs(args, ptrs);
const int argc = static_cast<int>(args.size());
char** argv = const_cast<char**>(ptrs.data());
ConsumedArgs consumed(argc, argv);
GemmaArgs gemma_args(argc, argv, consumed);
consumed.AbortIfUnconsumed();
}
static void CheckUnconsumed(const std::vector<std::string>& args,
size_t expected) {
std::vector<char*> ptrs;
FillPtrs(args, ptrs);
const int argc = static_cast<int>(args.size());
char** argv = const_cast<char**>(ptrs.data());
ConsumedArgs consumed(argc, argv);
GemmaArgs gemma_args(argc, argv, consumed);
ASSERT_EQ(expected, consumed.FirstUnconsumed());
}
// Note: do not use --help because it is not actually consumed; it is
// special-cased in `HasHelp`.
TEST(GemmaArgsTest, AllConsumedArgs) {
// Single arg
CheckAllConsumed({"gemma", "--weights=x"});
// Two args, one with =
CheckAllConsumed({"gemma", "--weights=x", "--verbosity=1"});
// Two args, one with extra value
CheckAllConsumed({"gemma", "--weights=x", "--verbosity", "2"});
// Two args with values
CheckAllConsumed({"gemma", "--verbosity", "2", "--deterministic=true"});
}
TEST(GemmaArgsTest, UnconsumedArgs) {
// Single unconsumed arg
CheckUnconsumed({"gemma", "--UNDEFINED"}, 1);
// Single unconsumed arg, no --
CheckUnconsumed({"gemma", "UNDEFINED"}, 1);
// Single unconsumed arg after valid arg
CheckUnconsumed({"gemma", "--weights=x", "--UNDEFINED"}, 2);
// Single unconsumed arg before valid arg
CheckUnconsumed({"gemma", "--UNDEFINED", "--weights=x"}, 1);
// Single unconsumed arg with = after valid arg
CheckUnconsumed({"gemma", "--weights=x", "--UNDEFINED=1"}, 2);
// Single unconsumed arg with = before valid arg
CheckUnconsumed({"gemma", "--UNDEFINED=false", "--weights=x"}, 1);
// Multiple unconsumed args
CheckUnconsumed({"gemma", "--UNDEFINED", "--XXX"}, 1);
// Multiple unconsumed args with valid arg between
CheckUnconsumed({"gemma", "--UNDEFINED", "--weights=x", "--XXX"}, 1);
}
} // namespace gcpp

View File

@ -16,6 +16,7 @@
#include "gemma/kv_cache.h" #include "gemma/kv_cache.h"
#include <stddef.h> #include <stddef.h>
#include <vector>
#include "gemma/configs.h" #include "gemma/configs.h"
#include "gemma/gemma_args.h" #include "gemma/gemma_args.h"
@ -50,8 +51,16 @@ KVCache KVCache::Copy() {
KVCache copy(kv_cache.Extents(), allocator_); KVCache copy(kv_cache.Extents(), allocator_);
CopyMat(kv_cache, copy.kv_cache); CopyMat(kv_cache, copy.kv_cache);
return copy; return copy;
} }
std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches) {
std::vector<KVCachePtr> ptrs;
ptrs.reserve(kv_caches.size());
for (size_t i = 0; i < kv_caches.size(); ++i) {
ptrs.push_back(kv_caches[i].ToPtr());
}
return ptrs;
}
} // namespace gcpp } // namespace gcpp

View File

@ -18,7 +18,10 @@
#include <stddef.h> #include <stddef.h>
#include "gemma/configs.h" // ModelConfig #include <optional>
#include <vector>
#include "gemma/configs.h" // ModelConfig
#include "gemma/gemma_args.h" // InferenceArgs #include "gemma/gemma_args.h" // InferenceArgs
#include "util/basics.h" // BF16 #include "util/basics.h" // BF16
#include "util/mat.h" #include "util/mat.h"
@ -27,18 +30,33 @@ namespace gcpp {
using KV_t = float; using KV_t = float;
// A non-owning view of a KVCache.
struct KVCachePtr {
bool IsEmpty() const { return kv_cache.Rows() == 0; }
size_t SeqLen() const;
MatPtrT<KV_t> kv_cache;
};
struct KVCache { struct KVCache {
KVCache(const ModelConfig& config, const InferenceArgs& inference_args, KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
const Allocator& allocator); const Allocator& allocator);
// Returns a deep copy of the KVCache. Use explicit function instead of // Returns a deep copy of the KVCache. Use explicit function instead of
// copy ctor to make the cost explicit. // copy ctor to make the cost explicit.
KVCache Copy(); KVCache Copy();
size_t SeqLen() const { return kv_cache.Rows(); } size_t SeqLen() const {
return kv_cache.Rows();
}
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2] MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
KVCachePtr ToPtr() {
return KVCachePtr{
.kv_cache = kv_cache,
};
}
private: private:
const Allocator& allocator_; const Allocator& allocator_;
@ -46,6 +64,13 @@ struct KVCache {
KVCache(const Extents2D& kv_extents, const Allocator& allocator); KVCache(const Extents2D& kv_extents, const Allocator& allocator);
}; };
inline size_t KVCachePtr::SeqLen() const {
return kv_cache.Rows();
}
// Convenience function to create views into KVCaches.
std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches);
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_ #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_

43
gemma/kv_cache_test.cc Normal file
View File

@ -0,0 +1,43 @@
#include "gemma/kv_cache.h"
#include <cstddef>
#include <vector>
#include "gtest/gtest.h"
#include "gemma/configs.h"
#include "gemma/gemma_args.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h"
namespace gcpp {
namespace {
TEST(KVCacheTest, KVCacheToPtrs) {
ModelConfig model_config;
model_config.max_seq_len = 1024;
model_config.num_layers = 2;
for (int i = 0; i < model_config.num_layers; ++i) {
model_config.layer_configs.push_back(LayerConfig());
model_config.layer_configs.back().kv_heads = 4;
model_config.layer_configs.back().qkv_dim = 256;
}
InferenceArgs inference_args;
inference_args.seq_len = 1024;
RuntimeConfig runtime_config;
runtime_config.attention_impl = AttentionImpl::kFlash;
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
std::vector<KVCache> caches;
caches.emplace_back(model_config, inference_args, runtime_config,
ctx.allocator);
inference_args.seq_len = 512;
caches.emplace_back(model_config, inference_args, runtime_config,
ctx.allocator);
std::vector<KVCachePtr> ptrs = ToKVCachePtrs({caches.data(), caches.size()});
ASSERT_EQ(ptrs.size(), 2);
EXPECT_EQ(ptrs[0].kv_cache.Row(0), caches[0].kv_cache.Row(0));
EXPECT_EQ(ptrs[1].kv_cache.Row(0), caches[1].kv_cache.Row(0));
}
} // namespace
} // namespace gcpp

View File

@ -221,6 +221,8 @@ static size_t DeduceNumLayers(const KeyVec& keys) {
// This works with or without type prefixes because it searches for substrings. // This works with or without type prefixes because it searches for substrings.
static int DeduceLayerTypes(const BlobReader& reader) { static int DeduceLayerTypes(const BlobReader& reader) {
int layer_types = 0; int layer_types = 0;
bool has_key_norm = false;
bool has_query_norm = false;
for (size_t key_idx = 0; key_idx < reader.Keys().size(); ++key_idx) { for (size_t key_idx = 0; key_idx < reader.Keys().size(); ++key_idx) {
const std::string& key = reader.Keys()[key_idx]; const std::string& key = reader.Keys()[key_idx];
if (key.find("qkv_ein_w") != std::string::npos) { // NOLINT if (key.find("qkv_ein_w") != std::string::npos) { // NOLINT
@ -232,6 +234,15 @@ static int DeduceLayerTypes(const BlobReader& reader) {
layer_types |= kDeduced448; layer_types |= kDeduced448;
} }
} }
if (key.find("key_norm") != std::string::npos) { // NOLINT
has_key_norm = true;
}
if (key.find("query_norm") != std::string::npos) { // NOLINT
has_query_norm = true;
}
}
if (has_key_norm && has_query_norm) {
layer_types |= kDeducedKqNorm;
} }
return layer_types; return layer_types;
} }

186
gemma/query.h Normal file
View File

@ -0,0 +1,186 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_QUERY_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_QUERY_H_
#include <vector>
#include "gemma/gemma_args.h"
#include "gemma/kv_cache.h"
#include "util/basics.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h"
namespace gcpp {
struct PerQuery {
PromptTokens prompt;
// Position in the KV cache: initially zero for the first turn, or when
// multi-turn is NOT desired. Incremented by prefill and `StreamAndUpdateEOS`.
size_t mutable_pos;
// Allows computing the last prefill token as `mutable_pos - initial_pos`,
// which might differ from `prompt.size() - 1` for prefix-LM.
size_t initial_pos;
// Zero for causal attention, or the end of the prefix for prefix-LM style
// attention in Paligemma.
size_t prefix_end;
KVCachePtr kv_cache;
// Previous token generated for this query, or the last prompt token. Will be
// fed into the next Transformer() call.
int prev_token = 0;
};
// Array of `PerQuery`. Referenced by `QBatch` and passed to `GenerateBatch`.
struct AllQueries {
AllQueries() = default;
// For `GenerateSingleT`: same prompt/pos, replicated for each KV cache.
AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
const hwy::Span<KVCachePtr>& kv_caches) {
per_query_.reserve(kv_caches.size());
for (size_t i = 0; i < kv_caches.size(); ++i) {
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
per_query_.push_back(PerQuery{
.prompt = prompt,
.mutable_pos = pos,
.initial_pos = pos,
.prefix_end = prefix_end,
.kv_cache = kv_caches[i],
});
}
}
AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
const hwy::Span<KVCache>& kv_caches)
: AllQueries(prompt, pos, prefix_end,
hwy::Span<KVCachePtr>(ToKVCachePtrs(kv_caches))) {}
// Batch of queries with initial position set to zero. Causal attention
// is requested via empty or all-zero `prefix_end`.
AllQueries(
const hwy::Span<const PromptTokens>& prompts,
const hwy::Span<KVCachePtr>& kv_caches,
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
per_query_.reserve(prompts.size());
for (size_t i = 0; i < prompts.size(); ++i) {
HWY_ASSERT(kv_caches.size() == 0 ||
kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
per_query_.push_back(PerQuery{
.prompt = prompts[i],
.mutable_pos = 0,
.initial_pos = 0,
.prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i],
.kv_cache = kv_caches.size() == 0 ? KVCachePtr() : kv_caches[i],
});
}
}
AllQueries(
const hwy::Span<const PromptTokens>& prompts,
const hwy::Span<KVCache>& kv_caches,
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>())
: AllQueries(prompts, hwy::Span<KVCachePtr>(ToKVCachePtrs(kv_caches)),
prefix_end) {}
void Reserve(size_t size) { per_query_.reserve(size); }
void Append(const PerQuery& query) { per_query_.push_back(query); }
size_t NumQueries() const { return per_query_.size(); }
PerQuery& operator[](size_t query_idx) {
HWY_DASSERT(query_idx < NumQueries());
return per_query_[query_idx];
}
const PerQuery& operator[](size_t query_idx) const {
HWY_DASSERT(query_idx < NumQueries());
return per_query_[query_idx];
}
private:
std::vector<PerQuery> per_query_;
};
// View into AllQueries: either a batch of queries, or a single query for use
// in PrefillTBatch or GenerateSingleT. Cheap to create because it holds a
// reference to AllQueries.
class QBatch {
public:
QBatch(size_t start, size_t max_size, AllQueries& queries)
: start_(start),
max_size_(max_size),
queries_(queries),
size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) {
HWY_ASSERT(max_size_ <= kMaxBatchSize);
HWY_DASSERT(size_ != 0);
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
query_idx_.reserve(size_);
for (size_t i = 0; i < size_; ++i) {
query_idx_.push_back(start_ + i);
}
}
// Returns a single-query view starting at `qi` relative to this batch.
QBatch Single(size_t qi) const { return QBatch(QueryIdx(qi), 1, queries_); }
// How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`.
size_t Size() const { return size_; }
// Returns index for use with `AllQueries` and `BatchStreamToken`.
size_t QueryIdx(size_t qi) const {
HWY_DASSERT(qi < size_);
return query_idx_[qi];
}
// Accessor functions to bridge the previous SoA and current AoS layout.
const PromptTokens& Prompt(size_t qi) const {
return queries_[QueryIdx(qi)].prompt;
}
size_t Pos(size_t qi) const { return queries_[QueryIdx(qi)].mutable_pos; }
size_t& MutablePos(size_t qi) { return queries_[QueryIdx(qi)].mutable_pos; }
size_t InitialPos(size_t qi) const {
return queries_[QueryIdx(qi)].initial_pos;
}
size_t PrefixEnd(size_t qi) const {
return queries_[QueryIdx(qi)].prefix_end;
}
KVCachePtr& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
// Let query_idx_[to] point to query `from` in queries_. This is only used
// when the QBatch has fewer slots than there are queries.
void Insert(size_t from, size_t to) {
if (from == to) return;
HWY_ASSERT(!queries_[from].kv_cache.IsEmpty());
HWY_ASSERT(queries_[to].kv_cache.IsEmpty());
// Conceptually, insert query `from` at slot `to`.
query_idx_[to] = from;
}
protected:
size_t start_;
size_t max_size_;
AllQueries& queries_;
std::vector<size_t> query_idx_;
size_t size_;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_QUERY_H_
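A rough usage sketch for these types (not part of this commit). It assumes `PromptTokens` is the span-of-token-ids alias from gemma_args.h, that the caller owns one `KVCache` per prompt, and that `prompts.size()` does not exceed `kMaxBatchSize`:
#include <vector>
#include "gemma/kv_cache.h"
#include "gemma/query.h"
namespace gcpp {
// Builds a batch view over `prompts` and `kv_caches` and walks its slots.
void WalkBatch(const std::vector<PromptTokens>& prompts,
               const hwy::Span<KVCache>& kv_caches) {
  AllQueries all_queries(
      hwy::Span<const PromptTokens>(prompts.data(), prompts.size()),
      kv_caches);  // prefix_end omitted => causal attention
  QBatch qbatch(/*start=*/0, /*max_size=*/prompts.size(), all_queries);
  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
    // QueryIdx maps the batch slot back to the AllQueries index.
    HWY_ASSERT(qbatch.Prompt(qi).size() == prompts[qbatch.QueryIdx(qi)].size());
  }
}
}  // namespace gcpp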

View File

@ -89,9 +89,11 @@ std::string GetPrompt(const InferenceArgs& inference) {
} }
// The main Read-Eval-Print Loop. // The main Read-Eval-Print Loop.
void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, void ReplGemma(const GemmaArgs& args, const Gemma& gemma, KVCache& kv_cache,
const Gemma& gemma, KVCache& kv_cache, MatMulEnv& env) { MatMulEnv& env) {
PROFILER_ZONE("Gen.misc"); PROFILER_ZONE("Gen.misc");
const InferenceArgs& inference = args.inference;
const int verbosity = inference.verbosity;
size_t abs_pos = 0; // across turns size_t abs_pos = 0; // across turns
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
size_t prompt_size = 0; size_t prompt_size = 0;
@ -113,12 +115,12 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
HWY_ASSERT(image.ReadPPM(inference.image_file.path)); HWY_ASSERT(image.ReadPPM(inference.image_file.path));
const size_t image_size = config.vit_config.image_size; const size_t image_size = config.vit_config.image_size;
image.Resize(image_size, image_size); image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {.verbosity = inference.verbosity, RuntimeConfig runtime_config = {.verbosity = verbosity,
.use_spinning = threading.spin}; .use_spinning = args.threading.spin};
double image_tokens_start = hwy::platform::Now(); double image_tokens_start = hwy::platform::Now();
gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image, gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
image_tokens, env); image_tokens, env);
if (inference.verbosity >= 1) { if (verbosity >= 1) {
double image_tokens_duration = hwy::platform::Now() - image_tokens_start; double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
fprintf(stderr, fprintf(stderr,
"\n\n[ Timing info ] Image token generation took: %d ms\n", "\n\n[ Timing info ] Image token generation took: %d ms\n",
@ -129,11 +131,15 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
// callback function invoked for each generated token. // callback function invoked for each generated token.
auto batch_stream_token = [&](size_t query_idx, size_t pos, int token, auto batch_stream_token = [&](size_t query_idx, size_t pos, int token,
float) { float) {
std::string token_text;
HWY_ASSERT(gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text));
HWY_ASSERT(pos == abs_pos); HWY_ASSERT(pos == abs_pos);
++abs_pos; ++abs_pos;
std::string token_text;
if (!gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text)) {
if (token == -2) return true; // Gemma 3 ViT?
HWY_WARN("Failed to decode token %d.", token);
}
const bool in_prompt = tokens_generated_this_turn < prompt_size; const bool in_prompt = tokens_generated_this_turn < prompt_size;
const bool first_response_token = tokens_generated_this_turn == prompt_size; const bool first_response_token = tokens_generated_this_turn == prompt_size;
++tokens_generated_this_turn; ++tokens_generated_this_turn;
@ -185,7 +191,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
TimingInfo timing_info = {.verbosity = inference.verbosity}; TimingInfo timing_info = {.verbosity = inference.verbosity};
RuntimeConfig runtime_config = {.verbosity = inference.verbosity, RuntimeConfig runtime_config = {.verbosity = inference.verbosity,
.batch_stream_token = batch_stream_token, .batch_stream_token = batch_stream_token,
.use_spinning = threading.spin}; .use_spinning = args.threading.spin};
inference.CopyTo(runtime_config); inference.CopyTo(runtime_config);
std::vector<int> prompt; std::vector<int> prompt;
size_t prefix_end = 0; size_t prefix_end = 0;
@ -248,14 +254,14 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
} }
} }
void Run(const LoaderArgs& loader, const ThreadingArgs& threading, void Run(const GemmaArgs& args) {
const InferenceArgs& inference) {
PROFILER_ZONE("Run.misc"); PROFILER_ZONE("Run.misc");
ThreadingContext ctx(threading); ThreadingContext ctx(args.threading);
MatMulEnv env(ctx); MatMulEnv env(ctx);
const InferenceArgs& inference = args.inference;
if (inference.verbosity >= 3) env.print_best = true; if (inference.verbosity >= 3) env.print_best = true;
const Gemma gemma(loader, inference, ctx); const Gemma gemma(args, ctx);
KVCache kv_cache(gemma.Config(), inference, ctx.allocator); KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
if (inference.verbosity >= 1) { if (inference.verbosity >= 1) {
@ -283,13 +289,12 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
if (inference.IsInteractive()) { if (inference.IsInteractive()) {
std::cout << "\033[2J\033[1;1H" // clear screen std::cout << "\033[2J\033[1;1H" // clear screen
<< kAsciiArtBanner << "\n\n"; << kAsciiArtBanner << "\n\n";
ShowConfig(loader, threading, inference, gemma.Config(), ShowConfig(args, gemma.Config(), gemma.WeightReadMode(), ctx);
gemma.WeightReadMode(), ctx);
std::cout << "\n" << instructions << "\n"; std::cout << "\n" << instructions << "\n";
} }
} }
ReplGemma(threading, inference, gemma, kv_cache, env); ReplGemma(args, gemma, kv_cache, env);
} }
} // namespace gcpp } // namespace gcpp
@ -298,17 +303,24 @@ int main(int argc, char** argv) {
gcpp::InternalInit(); gcpp::InternalInit();
{ {
// Negligible CPU time. // Negligible CPU time.
gcpp::LoaderArgs loader(argc, argv); gcpp::ConsumedArgs consumed(argc, argv);
gcpp::ThreadingArgs threading(argc, argv); gcpp::GemmaArgs args(argc, argv, consumed);
gcpp::InferenceArgs inference(argc, argv);
if (gcpp::HasHelp(argc, argv)) { if (gcpp::HasHelp(argc, argv)) {
std::cerr << gcpp::kAsciiArtBanner; std::cerr << gcpp::kAsciiArtBanner;
gcpp::ShowHelp(loader, threading, inference); fprintf(stderr,
"\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
"==========================================================\n\n"
"*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
"--weights gemma2-2b-it-sfp.sbs\n\n");
args.Help();
return 0; return 0;
} }
gcpp::Run(loader, threading, inference); // After `HasHelp` so that we print --help even if unconsumed args remain.
consumed.AbortIfUnconsumed();
gcpp::Run(args);
} }
PROFILER_PRINT_RESULTS(); // Must call outside the zone above. PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
return 0; return 0;

View File

@ -1,3 +1,18 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/tensor_info.h" #include "gemma/tensor_info.h"
#include <stddef.h> #include <stddef.h>

View File

@ -1,3 +1,18 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_

205
gemma/tensor_stats.cc Normal file
View File

@ -0,0 +1,205 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/tensor_stats.h"
#if GCPP_TENSOR_STATS
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <atomic>
#include <cmath>
#include <memory>
#include "io/io.h"
#include "util/mat.h"
#include "util/threading_context.h"
#include "util/zones.h"
#include "hwy/profiler.h" // StringTable
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/tensor_stats.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "ops/dot-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
float Correlation(const float* x, size_t num) {
double sum = 0.0;
for (size_t i = 0; i < num; ++i) {
sum += x[i];
}
const double mean = sum / static_cast<double>(num);
double numerator = 0.0;
double sum_sq_current = 0.0;
double sum_sq_next = 0.0;
for (size_t i = 0; i < num - 1; ++i) {
const double diff_current = static_cast<double>(x[i]) - mean;
const double diff_next = static_cast<double>(x[i + 1]) - mean;
numerator += diff_current * diff_next;
sum_sq_current += diff_current * diff_current;
sum_sq_next += diff_next * diff_next;
}
if (sum_sq_current == 0.0 || sum_sq_next == 0.0) return 0.0f;
const double denominator = std::sqrt(sum_sq_current * sum_sq_next);
const float corr = static_cast<float>(numerator / denominator);
HWY_DASSERT(-1.0f <= corr && corr <= 1.0f);
return corr;
}
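Illustrative spot check of the lag-1 autocorrelation above (not part of this commit): for the increasing ramp {1, 2, 3, 4}, the mean is 2.5, the numerator is 1.25 and both sums of squares are 2.75, so the result is 1.25 / 2.75 ≈ 0.45.
const float ramp[4] = {1.0f, 2.0f, 3.0f, 4.0f};
HWY_DASSERT(Correlation(ramp, 4) > 0.4f);  // exact value is 1.25 / 2.75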
// Only write tensor data the first time it is encountered per layer. This is
// a concurrent string+layer -> flag map which avoids std::mutex (incompatible
// with fibers). We use a string table to index into per-layer atomic flags.
static bool ShouldWrite(const char* name, size_t layer_idx) {
constexpr size_t kMaxNames = 128;
constexpr size_t kMaxLayers = 128;
HWY_DASSERT(layer_idx < kMaxLayers);
static hwy::StringTable<kMaxNames> s_table;
const size_t name_idx = s_table.Add(name);
static std::atomic_flag flags[kMaxNames * kMaxLayers] = {};
return !flags[name_idx * kMaxLayers + layer_idx].test_and_set(
std::memory_order_acq_rel);
}
std::unique_ptr<File> MaybeOpenFile(size_t layer_idx, const MatPtr& type_erased,
const Path& tensor_output) {
if (tensor_output.Empty()) return nullptr;
if (!ShouldWrite(type_erased.Name(), layer_idx)) return nullptr;
char path[1024];
snprintf(path, sizeof(path), "%s/%s_L%02zu_%zux%zu_%s.bin",
tensor_output.path.c_str(), type_erased.Name(), layer_idx,
type_erased.Rows(), type_erased.Cols(),
TypeName(type_erased.GetType()));
return OpenFileOrAbort(Path(path), "wb");
}
void MaybeWriteRow(const std::unique_ptr<File>& file, const MatPtr& type_erased,
size_t row_idx) {
if (!file) return;
const size_t bytes_per_row = type_erased.Cols() * type_erased.ElementBytes();
file->Write(type_erased.RowBytes(row_idx), bytes_per_row,
bytes_per_row * row_idx);
}
// First dispatch to the type, then parallel over rows, then vectorized
// decompress and Notify for each value.
void UpdateStatsT(TensorStats& stats, size_t layer_idx,
const MatPtr& type_erased, ThreadingContext& ctx, int flags,
size_t cluster_idx, Parallelism parallelism) {
std::unique_ptr<File> file =
MaybeOpenFile(layer_idx, type_erased, ctx.tensor_output);
if ((flags & kTensorStatsIsWeight) && layer_idx != 0) {
// Still compute stats, but remember not to print them.
stats.Get(layer_idx, 0).DoNotPrint();
}
CallUpcasted(&type_erased, [&](const auto* mat) {
const size_t cols = mat->Cols();
ParallelFor(
parallelism, mat->Rows(), ctx, cluster_idx, Callers::kTensorStats,
[&](size_t row_idx, size_t global_idx) {
GCPP_ZONE(ctx, global_idx, Zones::kGenStats);
auto* HWY_RESTRICT row = mat->Row(row_idx);
MaybeWriteRow(file, type_erased, row_idx);
using Packed = hwy::RemoveCvRef<decltype(*row)>;
PackedSpan<Packed> packed(const_cast<Packed*>(row), cols);
TensorStatsAccumulator& my_stats = stats.Get(layer_idx, global_idx);
my_stats.NotifyCond(ConditionNumber(row, cols));
namespace hn = hwy::HWY_NAMESPACE;
hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
HWY_ALIGN float buf[2 * hn::MaxLanes(df)];
size_t packed_ofs = 0;
if (cols >= 2 * NF) {
for (; packed_ofs <= cols - 2 * NF; packed_ofs += 2 * NF) {
VF v0, v1;
Decompress2(df, packed, packed_ofs, v0, v1);
hn::Store(v0, df, buf);
hn::Store(v1, df, buf + NF);
const VF min_mag = hn::Min(hn::Abs(v0), hn::Abs(v1));
const VF max_mag = hn::Max(hn::Abs(v0), hn::Abs(v1));
const float min = hn::ReduceMin(df, min_mag);
if (min != 0.0f) { // Avoid division by zero.
my_stats.NotifyGroup(min, hn::ReduceMax(df, max_mag));
}
for (size_t i = 0; i < 2 * NF; ++i) {
my_stats.Notify(buf[i], row_idx, packed_ofs + i);
}
my_stats.NotifyCorr(Correlation(buf, 2 * NF));
}
}
// Zero to two vectors remaining.
for (; packed_ofs < cols; packed_ofs += NF) {
const size_t remaining = HWY_MIN(NF, cols - packed_ofs);
DecompressAndZeroPad(df, packed, packed_ofs, buf, remaining);
// Skip NotifyGroup for this partial group.
for (size_t i = 0; i < remaining; ++i) {
my_stats.Notify(buf[i], row_idx, packed_ofs + i);
}
my_stats.NotifyCorr(Correlation(buf, remaining));
}
});
});
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(UpdateStatsT);
// Must reside in .cc file so that we can #include compress-inl.h.
void TensorStats::Notify(size_t layer_idx, const MatPtr& type_erased,
ThreadingContext& ctx, int flags, size_t cluster_idx,
Parallelism parallelism) {
// Ignore empty tensors.
if (type_erased.GetType() == Type::kUnknown || type_erased.Cols() == 0) {
return;
}
HWY_DYNAMIC_DISPATCH(UpdateStatsT)(*this, layer_idx, type_erased, ctx, flags,
cluster_idx, parallelism);
}
} // namespace gcpp
#endif // HWY_ONCE
#endif // GCPP_TENSOR_STATS
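A hypothetical driver sketch for these stats (not part of this commit). It requires building with GCPP_TENSOR_STATS=1 (the macro defaults to 0 in tensor_stats.h below); `mat`, `ctx` and `max_workers` are assumed to be supplied by the caller, with `max_workers` covering all workers `ParallelFor` may use:
TensorStats stats(/*num_layers=*/1, max_workers);
stats.Notify(/*layer_idx=*/0, mat, ctx);  // per-row stats; also dumps rows when
                                          // ctx.tensor_output is non-empty
stats.ReduceAndPrint("demo_tensor");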

347
gemma/tensor_stats.h Normal file
View File

@ -0,0 +1,347 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_STATS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_STATS_H_
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include "util/basics.h"
#include "hwy/base.h"
#ifndef GCPP_TENSOR_STATS
#define GCPP_TENSOR_STATS 0
#endif
#include "util/mat.h"
#include "util/threading_context.h"
#if GCPP_TENSOR_STATS
#include <cmath>
#include <vector>
#include "hwy/stats.h"
#endif // GCPP_TENSOR_STATS
namespace gcpp {
// For flags. Used to inhibit printing per-layer stats for weights.
HWY_INLINE_VAR constexpr int kTensorStatsIsWeight = 1;
#if GCPP_TENSOR_STATS
HWY_INLINE_VAR constexpr size_t kStatsMaxCols = 8192;
// Separate summary of the per-layer stats, updated by `TensorStatsAccumulator`.
// We pass per-layer statistics such as the mean value to `hwy::Stats::Notify`
// to see the distribution of per-layer means.
struct TensorStatsAcrossLayers {
bool IsEmpty() const { return s_frobenius.Count() == 0; }
void Print() {
const int skip = hwy::Stats::kNoGeomean;
fprintf(stderr, "frob %s\n", s_frobenius.ToString(skip).c_str());
fprintf(stderr, "cnd.min %s\n", s_cond_min.ToString(skip).c_str());
fprintf(stderr, "cnd.avg %s\n", s_cond_avg.ToString(skip).c_str());
fprintf(stderr, "cnd.max %s\n", s_cond_max.ToString(skip).c_str());
fprintf(stderr, "val.min %s\n", s_val_min.ToString(skip).c_str());
fprintf(stderr, "val.avg %s\n", s_val_avg.ToString(skip).c_str());
fprintf(stderr, "val.krt %s\n", s_val_kurt.ToString(skip).c_str());
fprintf(stderr, "mag.min %s\n", s_mag_min.ToString(skip).c_str());
fprintf(stderr, "mag.avg %s\n", s_mag_avg.ToString(skip).c_str());
fprintf(stderr, "mag.max %s\n", s_mag_max.ToString(skip).c_str());
if (hwy::ScalarAbs(s_corr_avg.Max()) > 0.05f) {
fprintf(stderr, "cor.avg %s\n", s_corr_avg.ToString(skip).c_str());
}
fprintf(stderr, "cor.max %s\n", s_corr_max.ToString(skip).c_str());
fprintf(stderr, "rng_avg %s\n", s_range_avg.ToString(skip).c_str());
fprintf(stderr, "exp.min %s\n", s_exp_min.ToString(skip).c_str());
fprintf(stderr, "exp.max %s\n", s_exp_max.ToString(skip).c_str());
fprintf(stderr, "exp.mod %s\n", s_exp_mode.ToString(skip).c_str());
if (s_exp_subnormal.Min() != 0.0f) {
fprintf(stderr, "exp.sub %s\n", s_exp_subnormal.ToString(skip).c_str());
}
if (s_big_cols.Count() != 0) {
fprintf(stderr, "bigCols %s\n", s_big_cols.ToString(skip).c_str());
const size_t modal_col = b_big_cols.ModalBinIdx();
const size_t num_outlier_cols = b_big_cols.NumNonzero();
if (num_outlier_cols > 256) {
fprintf(stderr, "bigCols: all up to %zu (max at %zu: %u layers):\n",
b_big_cols.LastNonzero(), modal_col, b_big_cols.Bin(modal_col));
} else {
fprintf(stderr, "bigCols (max at %zu: %u layers):\n", modal_col,
b_big_cols.Bin(modal_col));
for (size_t i = 0; i < kStatsMaxCols; ++i) {
if (b_big_cols.Bin(i) > 2) {
fprintf(stderr, " %3zu: %u\n", i, b_big_cols.Bin(i));
}
}
}
}
fprintf(stderr, "\n");
}
hwy::Stats s_frobenius;
hwy::Stats s_cond_min;
hwy::Stats s_cond_avg;
hwy::Stats s_cond_max;
hwy::Stats s_val_min;
hwy::Stats s_val_avg;
hwy::Stats s_val_kurt;
hwy::Stats s_mag_min;
hwy::Stats s_mag_avg;
hwy::Stats s_mag_max;
hwy::Stats s_corr_avg;
hwy::Stats s_corr_max;
hwy::Stats s_range_avg;
hwy::Stats s_exp_min;
hwy::Stats s_exp_max;
hwy::Stats s_exp_mode;
hwy::Stats s_exp_subnormal;
hwy::Stats s_big_cols; // total number of outlier cols
hwy::Bins<kStatsMaxCols> b_big_cols; // # layers with outlier per col
};
// Per-thread and layer.
class TensorStatsAccumulator {
public:
void Notify(float val, size_t row_idx, size_t col_idx) {
const double dval = static_cast<double>(val);
sum_sq_ += dval * dval;
s_val_.Notify(val);
const float mag = hwy::ScalarAbs(val);
if (HWY_UNLIKELY(mag >= 64.0f)) {
if (row_idx < kMaxBatchSize) b_big_row_.Notify(row_idx);
if (col_idx < kStatsMaxCols) b_big_col_.Notify(col_idx);
}
// Skip zero so we can see the lowest actual magnitude
if (mag != 0.0f && mag != -0.0f) s_mag_.Notify(mag);
const uint32_t binary32 = hwy::BitCastScalar<uint32_t>(mag);
// Use biased exponent because Bins wants unsigned values.
const uint32_t biased_exp = binary32 >> 23;
HWY_DASSERT(biased_exp < 256); // already cleared sign bit
b_exp256_.Notify(biased_exp);
}
void DoNotPrint() { skip_.fetch_or(1); }
bool ShouldPrint() const { return skip_.load() == 0; }
// Vector code computed the min/max of a group (= two vectors); this is
// faster than doing it in `Notify`.
void NotifyGroup(float min, float max) {
s_group_min_.Notify(min);
s_group_max_.Notify(max);
// Caller ensures min != 0.
s_group_range_.Notify(max / min);
}
void NotifyCorr(float corr) { s_corr_.Notify(corr); }
void NotifyCond(double cond) { s_cond_.Notify(cond); }
void Assimilate(const TensorStatsAccumulator& other) {
skip_.fetch_or(other.skip_.load());
sum_sq_ += other.sum_sq_;
b_exp256_.Assimilate(other.b_exp256_);
b_big_row_.Assimilate(other.b_big_row_);
b_big_col_.Assimilate(other.b_big_col_);
s_val_.Assimilate(other.s_val_);
s_mag_.Assimilate(other.s_mag_);
s_corr_.Assimilate(other.s_corr_);
s_group_min_.Assimilate(other.s_group_min_);
s_group_max_.Assimilate(other.s_group_max_);
s_group_range_.Assimilate(other.s_group_range_);
}
// Called on the per-layer representative after reducing across threads.
void NotifyAcrossLayer(TensorStatsAcrossLayers& s) {
s.s_frobenius.Notify(std::sqrt(sum_sq_));
s.s_cond_min.Notify(s_cond_.Min());
s.s_cond_avg.Notify(s_cond_.Mean());
s.s_cond_max.Notify(s_cond_.Max());
s.s_val_min.Notify(s_val_.Min());
s.s_val_avg.Notify(s_val_.Mean());
s.s_val_kurt.Notify(s_val_.Kurtosis());
s.s_mag_min.Notify(s_mag_.Min());
s.s_mag_avg.Notify(s_mag_.Mean());
s.s_mag_max.Notify(s_mag_.Max());
s.s_corr_avg.Notify(s_corr_.Mean());
s.s_corr_max.Notify(s_corr_.Max());
s.s_range_avg.Notify(s_group_range_.Mean());
const uint32_t subnormals = b_exp256_.Bin(0);
// Prevent subnormals from hiding the min exponent.
b_exp256_.ResetBin(0);
s.s_exp_min.Notify(b_exp256_.FirstNonzero());
s.s_exp_max.Notify(b_exp256_.LastNonzero());
s.s_exp_mode.Notify(b_exp256_.ModalBinIdx());
s.s_exp_subnormal.Notify(subnormals);
const uint32_t num_outliers = b_big_col_.NumNonzero();
if (num_outliers != 0) {
s.s_big_cols.Notify(num_outliers);
// For each col, count the number of layers that have an outlier there.
for (size_t i = 0; i < kStatsMaxCols; ++i) {
if (b_big_col_.Bin(i) != 0) s.b_big_cols.Notify(i);
}
}
}
bool IsEmpty() const { return s_val_.Count() == 0; }
void PrintAll() {
fprintf(stderr, "Frob %.2E\n", std::sqrt(sum_sq_));
const int skip = hwy::Stats::kNoGeomean;
fprintf(stderr, "cnd %s\n", s_cond_.ToString(skip).c_str());
fprintf(stderr, "val %s\n", s_val_.ToString(skip).c_str());
fprintf(stderr, "mag %s\n", s_mag_.ToString(skip).c_str());
fprintf(stderr, "corr %s\n", s_corr_.ToString(skip).c_str());
fprintf(stderr, "group_min %s\n", s_group_min_.ToString(skip).c_str());
fprintf(stderr, "group_max %s\n", s_group_max_.ToString(skip).c_str());
fprintf(stderr, "group_range %s\n", s_group_range_.ToString(skip).c_str());
b_exp256_.Print("exp");
PrintBinRanges(b_big_row_, "big row");
PrintBinRanges(b_big_col_, "big col");
fprintf(stderr, "\n");
}
private:
template <size_t N>
void PrintBinRanges(const hwy::Bins<N>& b, const char* name) {
uint64_t total = 0;
for (size_t i = 0; i < N; ++i) {
total += b.Bin(i);
}
if (total == 0) return;
// If all bins are at least 10% of a uniform distribution, print the range
// to vastly reduce the log size.
const size_t min = HWY_MAX(1, total / (N * 10));
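// e.g. N = 256 bins and total = 51200 values: uniform is 200 per bin, so min = 20.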
size_t last = 0;
for (; last < N; ++last) {
if (b.Bin(last) < min) break;
}
if (last >= N / 2) {
// Also require all subsequent bins to be zero, otherwise we should
// print the outlier bins.
bool all_zero = true;
for (size_t i = last + 1; i < N; ++i) {
if (b.Bin(i) != 0) {
all_zero = false;
break;
}
}
if (all_zero) {
fprintf(stderr, "%s: uniform up to %zu\n", name, last);
return;
}
}
b.Print(name, /*skip_zero=*/true);
}
double sum_sq_ = 0.0; // for Frobenius norm
hwy::Bins<256> b_exp256_; // exponent
hwy::Bins<kMaxBatchSize> b_big_row_;
hwy::Bins<kStatsMaxCols> b_big_col_;
hwy::Stats s_val_;
hwy::Stats s_mag_;
hwy::Stats s_cond_; // condition number
hwy::Stats s_corr_; // lag-1 autocorrelation
hwy::Stats s_group_min_;
hwy::Stats s_group_max_;
hwy::Stats s_group_range_;
std::atomic<int> skip_{0};
};
class TensorStats {
public:
TensorStats(size_t num_layers, size_t max_workers)
: num_layers_(num_layers),
max_workers_(max_workers),
acc_(num_layers * max_workers) {}
// Parallelized across rows. If `ctx.tensor_output` is not empty, writes
// tensor data to disk for offline analysis, once per tensor and layer.
void Notify(size_t layer_idx, const MatPtr& type_erased,
ThreadingContext& ctx, int flags = 0, size_t cluster_idx = 0,
Parallelism parallelism = Parallelism::kFlat);
// For use by `UpdateStatsT`.
TensorStatsAccumulator& Get(size_t layer_idx, size_t global_idx) {
const size_t idx = layer_idx * max_workers_ + global_idx;
HWY_DASSERT(idx < acc_.size());
return acc_[idx];
}
void ReduceAndPrint(const char* prefix) {
for (size_t layer_idx = 0; layer_idx < num_layers_; ++layer_idx) {
TensorStatsAccumulator& per_layer = Get(layer_idx, 0);
for (size_t global_idx = 1; global_idx < max_workers_; ++global_idx) {
per_layer.Assimilate(Get(layer_idx, global_idx));
}
if (per_layer.IsEmpty()) continue;
per_layer.NotifyAcrossLayer(across_layers_);
if (per_layer.ShouldPrint()) {
fprintf(stderr, "-------------------- %s %zu\n", prefix, layer_idx);
per_layer.PrintAll();
}
}
if (!across_layers_.IsEmpty()) {
fprintf(stderr, "================= across layers %s\n", prefix);
across_layers_.Print();
}
}
private:
size_t num_layers_;
size_t max_workers_;
std::vector<TensorStatsAccumulator> acc_;
TensorStatsAcrossLayers across_layers_;
};
#else // GCPP_TENSOR_STATS
class TensorStats {
public:
TensorStats(size_t, size_t) {}
void Notify(size_t, const MatPtr&, ThreadingContext&, int = 0, size_t = 0,
Parallelism = Parallelism::kFlat) {}
void ReduceAndPrint(const char*) {}
};
#endif // GCPP_TENSOR_STATS
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_STATS_H_
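
For orientation, a minimal caller-side sketch of the statistics classes above (hypothetical code, not part of this commit; it assumes `GCPP_TENSOR_STATS` is defined, and `kRows`, `kCols` and `values` are placeholder names). The intended entry point is `TensorStats::Notify(layer_idx, mat, ctx)`, which parallelizes over rows and can write tensors to disk; the sketch feeds values directly through the per-worker accumulator for brevity:

  // Accumulate per-value statistics for two layers on a single worker,
  // then reduce across workers and print per-layer and across-layer summaries.
  gcpp::TensorStats stats(/*num_layers=*/2, /*max_workers=*/1);
  for (size_t layer = 0; layer < 2; ++layer) {
    gcpp::TensorStatsAccumulator& acc = stats.Get(layer, /*global_idx=*/0);
    for (size_t r = 0; r < kRows; ++r) {
      for (size_t c = 0; c < kCols; ++c) {
        acc.Notify(values[r][c], r, c);
      }
    }
  }
  stats.ReduceAndPrint("example_tensor");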


@ -78,13 +78,9 @@ class VitAttention {
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim)); const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
PROFILER_ZONE("Gen.VitAttention.DotSoftmax"); PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
// Shift Q, K, VT to MatStorageT. MatPtrT<float>& Q = activations_.attention.vit_Q;
MatStorageT<float> Q("Q2", Extents2D(num_tokens_, qkv_dim), MatPtrT<float>& K = activations_.attention.vit_K;
env_.ctx.allocator, MatPadding::kPacked); MatPtrT<float>& C = activations_.attention.vit_C;
MatStorageT<float> K("K2", Extents2D(seq_len, qkv_dim), env_.ctx.allocator,
MatPadding::kPacked);
MatStorageT<float> C("C2", Extents2D(num_tokens_, seq_len),
env_.ctx.allocator, MatPadding::kPacked);
// Initialize att_out to zero prior to head loop. // Initialize att_out to zero prior to head loop.
ZeroInit(activations_.attention.att_out); ZeroInit(activations_.attention.att_out);
@ -295,19 +291,20 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
const size_t model_dim = model_config.vit_config.model_dim; const size_t model_dim = model_config.vit_config.model_dim;
const size_t patch_width = model_config.vit_config.patch_width; const size_t patch_width = model_config.vit_config.patch_width;
const size_t num_tokens = model_config.vit_config.seq_len; const size_t num_tokens = model_config.vit_config.seq_len;
const size_t patch_size = patch_width * patch_width * 3; const size_t patch_area = patch_width * patch_width * 3;
const hwy::Divisor div_patch_dim(patch_width);
HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim); HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size); HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_area);
HWY_DASSERT(activations.x.Cols() == model_dim); HWY_DASSERT(activations.x.Cols() == model_dim);
(void)model_dim; (void)model_dim;
// img/embedding/kernel has original shape (14, 14, 3, 1152) // img/embedding/kernel has original shape (14, 14, 3, 1152)
// H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3) // H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
// image_patches is (256, 14 * 14 * 3) // image_patches is (256, 14 * 14 * 3)
// Must be padded, see `DoDecompressA`. // Must be padded, see `DoDecompressA`.
MatStorageT<float> image_patches("patches", Extents2D(num_tokens, patch_size), MatStorageT<float> image_patches("patches", Extents2D(num_tokens, patch_area),
env.ctx.allocator, MatPadding::kOdd); env.ctx.allocator, MatPadding::kOdd);
for (size_t i = 0; i < num_tokens; ++i) { for (size_t i = 0; i < num_tokens; ++i) {
image.GetPatch(i, image_patches.Row(i)); image.GetPatch(i, div_patch_dim, image_patches.Row(i));
} }
CallMatMul(image_patches, weights.vit_img_embedding_kernel, CallMatMul(image_patches, weights.vit_img_embedding_kernel,
weights.vit_img_embedding_bias.PackedScale1(), env, activations.x); weights.vit_img_embedding_bias.PackedScale1(), env, activations.x);


@ -431,12 +431,12 @@ void WeightsPtrs::CopyFrom(const WeightsPtrs& other) {
void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners, void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
ThreadingContext& ctx) { ThreadingContext& ctx) {
const size_t cluster_idx = 0; const size_t cluster_idx = 0;
ParallelFor(ParallelismStrategy::kFlat, c_layers.size(), ctx, cluster_idx, ParallelFor(Parallelism::kFlat, c_layers.size(), ctx, cluster_idx,
Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) { Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) {
GetLayer(layer)->Fixup(mat_owners, ctx); GetLayer(layer)->Fixup(mat_owners, ctx);
}); });
ParallelFor(ParallelismStrategy::kFlat, vit_layers.size(), ctx, cluster_idx, ParallelFor(Parallelism::kFlat, vit_layers.size(), ctx, cluster_idx,
Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) { Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) {
VitLayer(layer)->Fixup(mat_owners, ctx); VitLayer(layer)->Fixup(mat_owners, ctx);
}); });
@ -527,7 +527,7 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
// Allocate in parallel because faulting in large tensors is slow. // Allocate in parallel because faulting in large tensors is slow.
ParallelFor( ParallelFor(
ParallelismStrategy::kFlat, tensors.size(), ctx, /*cluster_idx=*/0, Parallelism::kFlat, tensors.size(), ctx, /*cluster_idx=*/0,
Callers::kAllocateAndBindAll, [&](uint64_t task, size_t /*thread*/) { Callers::kAllocateAndBindAll, [&](uint64_t task, size_t /*thread*/) {
TensorToRead& tensor = tensors[task]; TensorToRead& tensor = tensors[task];
MatPtr& mat = *tensor.mat; MatPtr& mat = *tensor.mat;
@ -586,10 +586,9 @@ static void DecompressToBF16(MatPtr& mat,
static void ReadAllToBF16(const std::vector<TensorToRead>& tensors, static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
const BlobReader& reader, ThreadingContext& ctx) { const BlobReader& reader, ThreadingContext& ctx) {
// Especially TSAN is slow enough to warrant hierarchical parallelism. // Especially TSAN is slow enough to warrant hierarchical parallelism.
const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD const Parallelism parallelism =
? ParallelismStrategy::kHierarchical HWY_IS_DEBUG_BUILD ? Parallelism::kHierarchical : Parallelism::kFlat;
: ParallelismStrategy::kFlat; ParallelFor(parallelism, tensors.size(), ctx, /*cluster_idx=*/0,
ParallelFor(strategy, tensors.size(), ctx, /*cluster_idx=*/0,
Callers::kReadAllToBF16, [&](uint64_t task, size_t thread) { Callers::kReadAllToBF16, [&](uint64_t task, size_t thread) {
GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadAllToBF16); GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadAllToBF16);
const TensorToRead& tensor = tensors[task]; const TensorToRead& tensor = tensors[task];
@ -677,7 +676,7 @@ static void ReadBatches(const BlobReader& reader,
const std::vector<IOBatch>& batches, const std::vector<IOBatch>& batches,
ThreadingContext& ctx) { ThreadingContext& ctx) {
// >5x speedup from parallel reads when cached. // >5x speedup from parallel reads when cached.
ParallelFor(ParallelismStrategy::kHierarchical, batches.size(), ctx, ParallelFor(Parallelism::kHierarchical, batches.size(), ctx,
/*cluster_idx=*/0, Callers::kReadBatches, /*cluster_idx=*/0, Callers::kReadBatches,
[&](uint64_t task, size_t thread) { [&](uint64_t task, size_t thread) {
GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadBatches); GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadBatches);


@ -96,7 +96,8 @@ struct LayerWeightsPtrs {
// other values for purposes of the KV cache. // other values for purposes of the KV cache.
LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config, LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config,
const TensorInfoRegistry& tensors) const TensorInfoRegistry& tensors)
: finder_(LayerSuffix(layer_idx), tensors), : layer_idx(layer_idx),
finder_(LayerSuffix(layer_idx), tensors),
qkv_einsum_w(finder_("qkv_ein")), qkv_einsum_w(finder_("qkv_ein")),
qkv_einsum_w1(finder_("qkv1_w")), qkv_einsum_w1(finder_("qkv1_w")),
qkv_einsum_w2(finder_("qkv2_w")), qkv_einsum_w2(finder_("qkv2_w")),
@ -135,6 +136,7 @@ struct LayerWeightsPtrs {
} }
~LayerWeightsPtrs() = default; ~LayerWeightsPtrs() = default;
const size_t layer_idx;
const MatFinder finder_; const MatFinder finder_;
// Files either have qkv_einsum_w with 2 stacked matrices or separate // Files either have qkv_einsum_w with 2 stacked matrices or separate


@ -106,7 +106,7 @@ void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs,
ThreadingContext& ctx, size_t cluster_idx) { ThreadingContext& ctx, size_t cluster_idx) {
HWY_ASSERT(reader.Keys().size() == blobs.size()); HWY_ASSERT(reader.Keys().size() == blobs.size());
HWY_ASSERT(ranges.size() == blobs.size()); HWY_ASSERT(ranges.size() == blobs.size());
ParallelFor(ParallelismStrategy::kWithinCluster, blobs.size(), ctx, ParallelFor(Parallelism::kWithinCluster, blobs.size(), ctx,
cluster_idx, Callers::kTest, [&](size_t i, size_t /*thread*/) { cluster_idx, Callers::kTest, [&](size_t i, size_t /*thread*/) {
HWY_ASSERT(ranges[i].bytes == blobs[i].size()); HWY_ASSERT(ranges[i].bytes == blobs[i].size());
reader.file().Read(ranges[i].offset, ranges[i].bytes, reader.file().Read(ranges[i].offset, ranges[i].bytes,
@ -122,7 +122,7 @@ void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2,
const double t0 = hwy::platform::Now(); const double t0 = hwy::platform::Now();
HWY_WARN("Reading %zu GiB, %zu clusters: ", total_bytes >> 30, HWY_WARN("Reading %zu GiB, %zu clusters: ", total_bytes >> 30,
ctx.pools.NumClusters()); ctx.pools.NumClusters());
ParallelFor(ParallelismStrategy::kAcrossClusters, 2, ctx, 0, Callers::kTest, ParallelFor(Parallelism::kAcrossClusters, 2, ctx, 0, Callers::kTest,
[&](const size_t task, size_t cluster_idx) { [&](const size_t task, size_t cluster_idx) {
ReadBlobs(task ? reader1 : reader2, task ? ranges1 : ranges2, ReadBlobs(task ? reader1 : reader2, task ? ranges1 : ranges2,
task ? blobs1 : blobs2, ctx, cluster_idx); task ? blobs1 : blobs2, ctx, cluster_idx);
@ -189,7 +189,7 @@ void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2,
const double t0 = hwy::platform::Now(); const double t0 = hwy::platform::Now();
std::atomic<size_t> blobs_equal{}; std::atomic<size_t> blobs_equal{};
std::atomic<size_t> blobs_diff{}; std::atomic<size_t> blobs_diff{};
ParallelFor(ParallelismStrategy::kHierarchical, keys.size(), ctx, 0, ParallelFor(Parallelism::kHierarchical, keys.size(), ctx, 0,
Callers::kTest, [&](size_t i, size_t /*thread*/) { Callers::kTest, [&](size_t i, size_t /*thread*/) {
const size_t mismatches = const size_t mismatches =
BlobDifferences(blobs1[i], blobs2[i], keys[i]); BlobDifferences(blobs1[i], blobs2[i], keys[i]);


@ -488,11 +488,10 @@ void BlobWriter::Add(const std::string& key, const void* data, size_t bytes) {
EnqueueChunks(keys_.size() - 1, curr_offset_, bytes, EnqueueChunks(keys_.size() - 1, curr_offset_, bytes,
static_cast<const uint8_t*>(data), writes); static_cast<const uint8_t*>(data), writes);
const ParallelismStrategy strategy = file_->IsAppendOnly() const Parallelism parallelism =
? ParallelismStrategy::kNone file_->IsAppendOnly() ? Parallelism::kNone : Parallelism::kFlat;
: ParallelismStrategy::kFlat;
ParallelFor( ParallelFor(
strategy, writes.size(), ctx_, parallelism, writes.size(), ctx_,
/*cluster_idx=*/0, Callers::kBlobWriter, /*cluster_idx=*/0, Callers::kBlobWriter,
[this, &writes](uint64_t i, size_t /*thread*/) { [this, &writes](uint64_t i, size_t /*thread*/) {
const BlobRange& range = writes[i].range; const BlobRange& range = writes[i].range;


@ -131,7 +131,7 @@ class BlobWriter {
std::vector<size_t> blob_sizes_; std::vector<size_t> blob_sizes_;
ThreadingContext& ctx_; ThreadingContext& ctx_;
// Current offset in the file used for writing. // Current offset in the file used for writing.
int64_t curr_offset_ = 0; uint64_t curr_offset_ = 0;
}; };
} // namespace gcpp } // namespace gcpp


@ -130,7 +130,7 @@ TEST(BlobStoreTest, TestNumBlobs) {
HWY_ASSERT_EQ(reader.Keys().size(), num_blobs); HWY_ASSERT_EQ(reader.Keys().size(), num_blobs);
ParallelFor( ParallelFor(
ParallelismStrategy::kFlat, num_blobs, ctx, /*cluster_idx=*/0, Parallelism::kFlat, num_blobs, ctx, /*cluster_idx=*/0,
Callers::kTest, [&](uint64_t i, size_t /*thread*/) { Callers::kTest, [&](uint64_t i, size_t /*thread*/) {
HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(), HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(),
std::to_string(i).c_str()); std::to_string(i).c_str());


@ -20,6 +20,8 @@
#include <stdio.h> #include <stdio.h>
#include <limits> #include <limits>
#include <string>
#include <vector>
#include <type_traits> #include <type_traits>
#include "hwy/tests/hwy_gtest.h" #include "hwy/tests/hwy_gtest.h"


@ -110,7 +110,8 @@ class FilePosix : public File {
HWY_WARN( HWY_WARN(
"Read failure at pos %zu within size %zu with offset %zu and " "Read failure at pos %zu within size %zu with offset %zu and "
"errno %d\n", "errno %d\n",
pos, size, offset, errno); static_cast<size_t>(pos), static_cast<size_t>(size),
static_cast<size_t>(offset), errno);
break; break;
} }
pos += bytes_read; pos += bytes_read;
@ -130,7 +131,8 @@ class FilePosix : public File {
HWY_WARN( HWY_WARN(
"Write failure at pos %zu within size %zu with offset %zu and " "Write failure at pos %zu within size %zu with offset %zu and "
"errno %d\n", "errno %d\n",
pos, size, offset, errno); static_cast<size_t>(pos), static_cast<size_t>(size),
static_cast<size_t>(offset), errno);
break; break;
} }
pos += bytes_written; pos += bytes_written;
@ -194,9 +196,9 @@ std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) {
namespace gcpp { namespace gcpp {
std::unique_ptr<File> OpenFileOrAbort(const Path& filename, const char* mode) { std::unique_ptr<File> OpenFileOrAbort(const Path& filename, const char* mode) {
std::unique_ptr<File> file = OpenFileOrNull(filename, "r"); std::unique_ptr<File> file = OpenFileOrNull(filename, mode);
if (!file) { if (!file) {
HWY_ABORT("Failed to open %s", filename.path.c_str()); HWY_ABORT("Failed to open %s, errno %d", filename.path.c_str(), errno);
} }
return file; return file;
} }
@ -234,7 +236,9 @@ bool IOBatch::Add(void* mem, size_t bytes) {
return true; return true;
} }
void InternalInit() { int InternalInit() {
// currently unused, except for init list ordering in GemmaEnv.
return 0;
} }
uint64_t IOBatch::Read(const File& file) const { uint64_t IOBatch::Read(const File& file) const {
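
A note on the `InternalInit` signature change above: returning `int` lets callers sequence it ahead of other members via an initializer. A minimal analogous sketch (hypothetical `Env` type, not the actual `GemmaEnv` code):

  struct Env {
    int init_ = gcpp::InternalInit();  // runs before the members declared below
    // ... members that perform I/O or read flags during construction ...
  };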


@ -150,7 +150,7 @@ std::string ReadFileToString(const Path& path);
// No-op in open-source. Must be called at the beginning of a binary, before // No-op in open-source. Must be called at the beginning of a binary, before
// any I/O or flag usage. // any I/O or flag usage.
void InternalInit(); int InternalInit();
} // namespace gcpp } // namespace gcpp


@ -23,7 +23,9 @@ namespace gcpp {
namespace { namespace {
struct WriterArgs : public ArgsBase<WriterArgs> { struct WriterArgs : public ArgsBase<WriterArgs> {
WriterArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } WriterArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
Path output_weights; Path output_weights;
@ -38,12 +40,15 @@ struct WriterArgs : public ArgsBase<WriterArgs> {
} // namespace gcpp } // namespace gcpp
int main(int argc, char** argv) { int main(int argc, char** argv) {
gcpp::WriterArgs args(argc, argv); gcpp::ConsumedArgs consumed(argc, argv);
if (args.output_weights.Empty()) { gcpp::GemmaArgs args(argc, argv, consumed);
gcpp::WriterArgs writer_args(argc, argv, consumed);
if (writer_args.output_weights.Empty()) {
HWY_ABORT("Missing --output_weights flag, a file for the model weights."); HWY_ABORT("Missing --output_weights flag, a file for the model weights.");
} }
consumed.AbortIfUnconsumed();
gcpp::GemmaEnv env(argc, argv); gcpp::GemmaEnv env(args);
env.GetGemma()->Save(args.output_weights, env.Env().ctx); env.GetGemma()->Save(writer_args.output_weights, env.Env().ctx);
return 0; return 0;
} }


@ -413,7 +413,8 @@ using DotKernelDefault =
template <class D, typename WT, typename VT> template <class D, typename WT, typename VT>
HWY_INLINE float Dot(D d, const PackedSpan<const WT>& w, size_t w_ofs, HWY_INLINE float Dot(D d, const PackedSpan<const WT>& w, size_t w_ofs,
const VT* HWY_RESTRICT vec, size_t num) { const VT* HWY_RESTRICT vec, size_t num) {
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelDefault()); return DecompressAndCall(d, w, w_ofs, MakeConstSpan(vec, num),
DotKernelDefault());
} }
// Adapter for two pointers, no bounds checking. // Adapter for two pointers, no bounds checking.


@ -891,18 +891,6 @@ class DotStats {
hwy::Stats s_times[kVariants]; hwy::Stats s_times[kVariants];
}; };
// Returns normalized value in [-1, 1).
float RandomFloat(RngStream& rng) {
const uint32_t exp = hwy::BitCastScalar<uint32_t>(1.0f);
const uint32_t mantissa_mask = hwy::MantissaMask<float>();
const uint32_t representation = exp | (rng() & mantissa_mask);
const float f12 = hwy::BitCastScalar<float>(representation);
HWY_DASSERT(1.0f <= f12 && f12 < 2.0f); // exponent is 2^0, only mantissa
const float f = (2.0f * (f12 - 1.0f)) - 1.0f;
HWY_DASSERT(-1.0f <= f && f < 1.0f);
return f;
}
// `raw` holds the decompressed values, so that the test measures only the // `raw` holds the decompressed values, so that the test measures only the
// error from the Dot algorithms, not the compression. // error from the Dot algorithms, not the compression.
template <typename Packed> template <typename Packed>
@ -1126,7 +1114,7 @@ void TestAllDot() {
std::array<DotStats, kMaxWorkers> all_stats; std::array<DotStats, kMaxWorkers> all_stats;
ParallelFor( ParallelFor(
ParallelismStrategy::kWithinCluster, kReps, ctx, 0, Callers::kTest, Parallelism::kWithinCluster, kReps, ctx, 0, Callers::kTest,
[&](size_t rep, size_t thread) { [&](size_t rep, size_t thread) {
float* HWY_RESTRICT pa = a.Row(thread); float* HWY_RESTRICT pa = a.Row(thread);
float* HWY_RESTRICT pb = b.Row(thread); float* HWY_RESTRICT pb = b.Row(thread);


@ -837,10 +837,11 @@ class MMImpl {
hwy::platform::InvariantTicksPerSecond(); hwy::platform::InvariantTicksPerSecond();
const double flops = 2 * M * K * N * num_B / min_elapsed; // * 2 for FMA const double flops = 2 * M * K * N * num_B / min_elapsed; // * 2 for FMA
if (HWY_UNLIKELY(env.print_measurement && tuner.ShouldPrint())) { if (HWY_UNLIKELY(env.print_measurement && tuner.ShouldPrint())) {
fprintf(stderr, "%zu,%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu\n", fprintf(
M, K, N, num_B, flops * 1E-9, min_elapsed * 1E3, cfg.MR(), stderr,
cfg.MC(), cfg.KC(), cfg.NC(), StringFromOrder(cfg.Order()), "%4zu,%4zu,%4zu,B%zu,%7.1f,%.2f ms, MR%zu,%4zu,%4zu,%5zu,%-7s,%zu\n",
cfg.InnerTasks()); M, K, N, num_B, flops * 1E-9, min_elapsed * 1E3, cfg.MR(), cfg.MC(),
cfg.KC(), cfg.NC(), StringFromOrder(cfg.Order()), cfg.InnerTasks());
} }
if (HWY_UNLIKELY(env.print_best && tuner.Best())) { if (HWY_UNLIKELY(env.print_best && tuner.Best())) {
const auto ratio = [&tuner](uint64_t ticks) -> double { const auto ratio = [&tuner](uint64_t ticks) -> double {
@ -850,7 +851,8 @@ class MMImpl {
const MMConfig& best = *tuner.Best(); const MMConfig& best = *tuner.Best();
fprintf( fprintf(
stderr, stderr,
"\n%zu,%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%.2f,%.2f\n", "\n%4zu,%4zu,%4zu,B%zu,%7.1f,%.2f ms, MR%zu,%4zu,%4zu,%5zu,%-7s,%zu, "
"%.2fx,%.2fx\n",
M, K, N, num_B, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(), M, K, N, num_B, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(),
best.KC(), best.NC(), StringFromOrder(best.Order()), best.KC(), best.NC(), StringFromOrder(best.Order()),
best.InnerTasks(), ratio(tuner.WorstMinTicks()), best.InnerTasks(), ratio(tuner.WorstMinTicks()),
@ -906,8 +908,8 @@ class MMLoops {
const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT); const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT);
HWY_DASSERT(args.ranges_mc.NumTasks() == 1); HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
HWY_DASSERT(args.ranges_kc.NumTasks() == 1); HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
const IndexRange& range_mc = args.ranges_mc.Range(0); const IndexRange& range_mc = args.ranges_mc.Range(0); // whole M
const IndexRange& range_kc = args.ranges_kc.Range(0); const IndexRange& range_kc = args.ranges_kc.Range(0); // whole K
parallel.ForN( parallel.ForN(
args.env.ctx, args.range_n, MultipleN(sizeof(TC), args.line_bytes), args.env.ctx, args.range_n, MultipleN(sizeof(TC), args.line_bytes),
@ -941,7 +943,7 @@ class MMLoops {
const MMArgs& args) { const MMArgs& args) {
const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT_K); const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT_K);
HWY_DASSERT(args.ranges_mc.NumTasks() == 1); HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
const IndexRange& range_mc = args.ranges_mc.Range(0); const IndexRange& range_mc = args.ranges_mc.Range(0); // whole M
parallel.ForN(args.env.ctx, args.range_n, parallel.ForN(args.env.ctx, args.range_n,
MultipleN(sizeof(TC), args.line_bytes), args.inner_tasks, MultipleN(sizeof(TC), args.line_bytes), args.inner_tasks,
@ -977,7 +979,7 @@ class MMLoops {
const MMArgs& args) { const MMArgs& args) {
const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT_MT); const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT_MT);
HWY_DASSERT(args.ranges_kc.NumTasks() == 1); HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
const IndexRange& range_kc = args.ranges_kc.Range(0); const IndexRange& range_kc = args.ranges_kc.Range(0); // whole K
parallel.ForRangesMC_NC( parallel.ForRangesMC_NC(
args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx, args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx,
@ -1158,8 +1160,9 @@ HWY_NOINLINE MMPerKey* TwoMatMul(const MatPtrT<BF16>& A, const MatPtrT<TB>& B1,
HWY_ASSERT(K <= MMEntireA::kMaxK); HWY_ASSERT(K <= MMEntireA::kMaxK);
HWY_ASSERT(N % kNR == 0); HWY_ASSERT(N % kNR == 0);
MMImpl::EnsureAligned(A, cache.VectorBytes()); MMImpl::EnsureAligned(A, cache.VectorBytes());
tuner.SetCandidates( const size_t max_M = MMKeys::BucketM(M);
MMCandidates(cache, M, K, N, num_B, sizeof(BF16), env.print_config)); tuner.SetCandidates(MMCandidates(cache, max_M, K, N, num_B, sizeof(BF16),
env.print_config));
} }
const MMConfig& cfg = tuner.NextConfig(); const MMConfig& cfg = tuner.NextConfig();


@ -21,6 +21,7 @@
#include <stdint.h> #include <stdint.h>
#include <stdio.h> #include <stdio.h>
#include <string>
#include <vector> #include <vector>
#include "util/allocator.h" #include "util/allocator.h"
@ -46,7 +47,9 @@ size_t RoundDownWithFloor(size_t value, size_t multiple) {
// multiple of `multiple`, or 0 if none exists. // multiple of `multiple`, or 0 if none exists.
size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim, size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
const size_t multiple) { const size_t multiple) {
HWY_DASSERT(end != 0 && dim != 0 && multiple != 0); HWY_DASSERT(end != 0);
HWY_DASSERT(dim != 0);
HWY_DASSERT(multiple != 0);
size_t prev = RoundDownWithFloor(end, multiple); size_t prev = RoundDownWithFloor(end, multiple);
// Avoid returning `end` if rounding down had no effect. // Avoid returning `end` if rounding down had no effect.
if (prev == end) prev -= multiple; if (prev == end) prev -= multiple;
@ -62,10 +65,10 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
// and holds most of their arguments in member variables. // and holds most of their arguments in member variables.
class GenerateCandidates { class GenerateCandidates {
public: public:
GenerateCandidates(const CacheInfo& cache, size_t M, size_t K, size_t N, GenerateCandidates(const CacheInfo& cache, size_t max_M, size_t K, size_t N,
size_t num_B, size_t sizeof_TC, bool print_config) size_t num_B, size_t sizeof_TC, bool print_config)
: cache_(cache), : cache_(cache),
M_(M), max_M_(max_M),
K_(K), K_(K),
N_(N), N_(N),
num_B_(num_B), num_B_(num_B),
@ -89,14 +92,14 @@ class GenerateCandidates {
for (size_t mc : MC(mr, kc, order)) { for (size_t mc : MC(mr, kc, order)) {
for (size_t nc : NC(mr, mc, kc, order)) { for (size_t nc : NC(mr, mc, kc, order)) {
for (int inner_tasks : all_inner_tasks) { for (int inner_tasks : all_inner_tasks) {
const MMConfig config(K_, N_, mr, mc, kc, nc, kc_multiple_, const MMConfig config(max_M_, K_, N_, mr, mc, kc, nc,
nc_multiple_, order, inner_tasks); kc_multiple_, nc_multiple_, order,
const size_t M_tasks = config.RangesOfMC(M_).NumTasks(); inner_tasks);
const size_t M_tasks = config.RangesOfMC(max_M_).NumTasks();
const size_t K_tasks = config.RangesOfKC(K_).NumTasks(); const size_t K_tasks = config.RangesOfKC(K_).NumTasks();
// Blocks only make sense when there are multiple M tasks. // Do not use single-MC/KC order if there are multiple.
if (IsBlock(order) != (M_tasks > 1)) continue; if (IsOneMC(order) != (M_tasks == 1)) continue;
// Single KC only makes sense when there is a single K task.
if (IsOneKC(order) != (K_tasks == 1)) continue; if (IsOneKC(order) != (K_tasks == 1)) continue;
candidates.push_back(config); candidates.push_back(config);
@ -114,6 +117,25 @@ class GenerateCandidates {
private: private:
using SizeVec = std::vector<size_t>; using SizeVec = std::vector<size_t>;
// Concatenate and print once because this can be called concurrently.
void MaybePrintSizes(size_t dim, size_t max, const char* caption,
const SizeVec& sizes) const {
if (!print_config_ || sizes.empty()) return;
std::string out("num_B ");
out += std::to_string(num_B_);
out += " (";
out += std::to_string(dim);
out += ", max ";
out += std::to_string(max);
out += ") ";
out += caption;
out += ": ";
for (size_t size : sizes) {
out += std::to_string(size) + " ";
}
fprintf(stderr, "%s\n", out.c_str());
}
// How many rows of A per call to `MMKernel::LoopKC`. Lower values may // How many rows of A per call to `MMKernel::LoopKC`. Lower values may
// be better for SIMD targets with fewer registers. // be better for SIMD targets with fewer registers.
SizeVec MR() const { SizeVec MR() const {
@ -125,14 +147,14 @@ class GenerateCandidates {
SizeVec all_mr; SizeVec all_mr;
all_mr.reserve(3); all_mr.reserve(3);
// AVX2's 16 registers are not enough for four rows, but SSE4 may benefit. // AVX2's 16 registers are not enough for four rows, but SSE4 may benefit.
if (M_ >= kMaxMR && !is_avx2) all_mr.push_back(kMaxMR); if (max_M_ >= kMaxMR && !is_avx2) all_mr.push_back(kMaxMR);
// Allow for AVX-512 but not SSE4 (for which 4 are usually better). Also // Allow for AVX-512 but not SSE4 (for which 4 are usually better). Also
// enable if not enough rows for 4. // enable if not enough rows for 4.
if (M_ >= 2 && (M_ < kMaxMR || (!is_sse && !is_wasm))) { if (max_M_ >= 2 && (max_M_ < kMaxMR || (!is_sse && !is_wasm))) {
all_mr.push_back(size_t{2}); all_mr.push_back(size_t{2});
} }
// Even SSE4 usually prefers 2 rows; only enable for single rows. // Even SSE4 usually prefers 2 rows; only enable for single rows.
if (M_ == 1) all_mr.push_back(size_t{1}); if (max_M_ == 1) all_mr.push_back(size_t{1});
HWY_ASSERT(!all_mr.empty()); HWY_ASSERT(!all_mr.empty());
return all_mr; return all_mr;
} }
@ -143,18 +165,26 @@ class GenerateCandidates {
for (size_t order_idx = 0;; ++order_idx) { for (size_t order_idx = 0;; ++order_idx) {
const MMOrder order = static_cast<MMOrder>(order_idx); const MMOrder order = static_cast<MMOrder>(order_idx);
if (StringFromOrder(order) == nullptr) return orders; // done if (StringFromOrder(order) == nullptr) return orders; // done
// 2D blocking is useless for a single row of M. // Multiple-MC is useless for a single row of M.
if (IsBlock(order) && M_ <= mr) continue; if (!IsOneMC(order) && max_M_ <= mr) continue;
// Conversely, N-only parallelism is uncompetitive for large M. // Conversely, N-only parallelism is uncompetitive for large M.
if (!IsBlock(order) && M_ >= kMaxTilesM * mr) continue; if (IsOneMC(order) && max_M_ >= 8 * mr) continue;
orders.push_back(order); orders.push_back(order);
} }
} }
// The number of A and B columns to read between updating `C`. // The number of A and B columns to read between updating `C`.
SizeVec KC(size_t mr, MMOrder order) const { SizeVec KC(size_t mr, MMOrder order) const {
if (IsOneKC(order)) {
// A single KC range is infeasible when K exceeds the max. The caller
// will skip all configs with `order`.
if (K_ > kMaxKC) return SizeVec();
// Must return the actual value: although ignored by `RangesOfKC`, this
// will be used in MC() and NC().
return SizeVec(1, K_);
}
// `LoopKC` handles up to `mr` rows of A. // `LoopKC` handles up to `mr` rows of A.
const size_t rows_a = HWY_MIN(M_, mr); const size_t rows_a = HWY_MIN(max_M_, mr);
// After looping over `kc` columns, we write `mr x 4` outputs and 16 vector // After looping over `kc` columns, we write `mr x 4` outputs and 16 vector
// `buf`. To amortize the write cost, we want to maximize `kc`. However, it // `buf`. To amortize the write cost, we want to maximize `kc`. However, it
@ -186,7 +216,7 @@ class GenerateCandidates {
// If we can afford a single K task, that's usually best; only try one // If we can afford a single K task, that's usually best; only try one
// more. Otherwise, blocks may require smaller kc (more options). // more. Otherwise, blocks may require smaller kc (more options).
const size_t reps = (kc_max == K_) ? 1 : IsBlock(order) ? 3 : 2; const size_t reps = (kc_max == K_) ? 1 : IsOneMC(order) ? 2 : 3;
size_t prev = kc_max; size_t prev = kc_max;
for (size_t rep = 0; rep < reps; ++rep) { for (size_t rep = 0; rep < reps; ++rep) {
@ -196,22 +226,27 @@ class GenerateCandidates {
} }
} }
if (print_config_ && all_kc.size() > 1) { MaybePrintSizes(K_, kc_max, "KC", all_kc);
fprintf(stderr, "num_B %zu: KC: ", num_B_);
for (size_t kc : all_kc) {
fprintf(stderr, "%zu ", kc);
}
fprintf(stderr, "\n");
}
return all_kc; return all_kc;
} }
// The number of (L2 resident) A rows for `A2C0` to loop over. // The number of (L2 resident) A rows for `A2C0` to loop over.
SizeVec MC(size_t mr, size_t kc, MMOrder order) const { SizeVec MC(size_t mr, size_t kc, MMOrder order) const {
if (max_M_ <= mr) return SizeVec(1, max_M_);
if (IsOneMC(order)) {
// A single MC range is infeasible when M exceeds the max. The caller
// will skip all configs with `order`.
if (max_M_ > kMaxMC) return SizeVec();
// Must return the actual value: although ignored by `RangesOfMC`, this
// will be used in NC().
return SizeVec(1, max_M_);
}
// Typically 12-24K. The B rows are pinned in L1, but also occupy L2 because // Typically 12-24K. The B rows are pinned in L1, but also occupy L2 because
// it is typically inclusive. // it is typically inclusive.
const size_t bytes_b = kNR * kc * (sizeof(SfpStream) + sizeof(BF16)); const size_t bytes_b = kNR * kc * (sizeof(SfpStream) + sizeof(BF16));
// `kc` was chosen to fit in L1, hence this should not exceed L2.
HWY_ASSERT(bytes_b <= cache_.L2Bytes());
// Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the // Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the
// packed B. We want `mc * kc` elements of A to fit in L2, alongside // packed B. We want `mc * kc` elements of A to fit in L2, alongside
@ -219,35 +254,45 @@ class GenerateCandidates {
const size_t bytes_per_mc = kc * sizeof(BF16) + cache_.LineBytes(); const size_t bytes_per_mc = kc * sizeof(BF16) + cache_.LineBytes();
size_t mc_max = hwy::DivCeil(cache_.L2Bytes() - bytes_b, bytes_per_mc); size_t mc_max = hwy::DivCeil(cache_.L2Bytes() - bytes_b, bytes_per_mc);
mc_max = HWY_MIN(mc_max, HWY_MIN(kMaxBatchSize, kMaxMC)); mc_max = HWY_MIN(mc_max, HWY_MIN(kMaxBatchSize, kMaxMC));
HWY_DASSERT(mc_max != 0); mc_max = HWY_MIN(mc_max, max_M_);
mc_max = HWY_MIN(mc_max, M_); HWY_ASSERT(mc_max != 0);
mc_max = hwy::RoundDownTo(mc_max, mr);
SizeVec all_mc(1, mc_max); SizeVec all_mc;
// Larger MC is better for non-blocks, otherwise we want more small options, all_mc.reserve(6);
// especially for two B.
const size_t reps = !IsBlock(order) ? 2 : (2 + num_B_);
size_t prev = mc_max; const size_t rounded_M = HWY_MAX(mr, hwy::RoundDownTo(max_M_, mr));
for (size_t rep = 0; rep < reps; ++rep) { size_t prev = hwy::RoundDownTo(mc_max, mr);
prev = PrevDivisor(1, prev, M_, mr);
if (prev >= mc_max || prev == 0) break; // If mc_max is large enough, allow using the whole range without rounding
// down (which may require two ranges).
if (mc_max == max_M_ && (max_M_ % mr) != 0) {
all_mc.push_back(max_M_);
// The next option should be considerably smaller than `max_M_`.
prev = HWY_MAX(mr, hwy::RoundDownTo(3 * prev / 4, mr));
} else {
all_mc.push_back(prev); all_mc.push_back(prev);
} }
// Blocks: largest is not useful. // We know `order` is multiple MC, where more/smaller values of `mc` are
if (IsBlock(order) && all_mc.size() > 1) { // helpful, especially for two B, hence add iterations.
all_mc.erase(all_mc.begin(), all_mc.begin() + 1); const size_t reps = 2 + num_B_;
} for (size_t rep = 0; rep < reps; ++rep) {
prev = PrevDivisor(mr, prev, rounded_M, mr);
if (print_config_ && all_mc.size() > 1) { if (prev == 0) break; // none found
fprintf(stderr, "num_B %zu: MC: ", num_B_); if (prev == mr) {
for (size_t mc : all_mc) { if (all_mc.back() != prev) all_mc.push_back(prev);
fprintf(stderr, "%zu ", mc); break;
} }
fprintf(stderr, "\n"); if (prev <= mc_max / 8) break;
all_mc.push_back(prev);
} }
if (all_mc.size() <= 2) {
if (max_M_ > mr) all_mc.push_back(max_M_ / 2);
if (mc_max > mr) all_mc.push_back(mc_max / 2);
}
MaybePrintSizes(max_M_, mc_max, "MC", all_mc);
return all_mc; return all_mc;
} }
@ -257,7 +302,7 @@ class GenerateCandidates {
// Only if there will be reuse of B: choose the largest `nc_max` (C cols) // Only if there will be reuse of B: choose the largest `nc_max` (C cols)
// such that `nc x kc` of B and `mc x nc` of `C` fit in L3. Otherwise, // such that `nc x kc` of B and `mc x nc` of `C` fit in L3. Otherwise,
// leave it unbounded. // leave it unbounded.
if (M_ > mr) { if (max_M_ > mr) {
const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * sizeof_TC_); const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * sizeof_TC_);
nc_max = HWY_MIN(hwy::DivCeil(cache_.L3Bytes(), bytes_per_nc), kMaxNC); nc_max = HWY_MIN(hwy::DivCeil(cache_.L3Bytes(), bytes_per_nc), kMaxNC);
} }
@ -271,8 +316,8 @@ class GenerateCandidates {
nc_max = RoundDownWithFloor(N_ / 2, nc_multiple_); nc_max = RoundDownWithFloor(N_ / 2, nc_multiple_);
} }
// Non-block calls ForNP, which ignores `range_nc` and uses `range_np`. // Single-MC calls `ForNP`, which ignores `range_nc`.
if (!IsBlock(order)) return SizeVec(1, N_); if (IsOneMC(order)) return SizeVec(1, N_);
SizeVec all_nc(1, nc_max); SizeVec all_nc(1, nc_max);
@ -282,7 +327,7 @@ class GenerateCandidates {
// hence autotune a wider range of nc than the other dimensions. // hence autotune a wider range of nc than the other dimensions.
size_t reps = 9 + num_B_; size_t reps = 9 + num_B_;
// For small M, we can afford larger NC, hence allow fewer small options. // For small M, we can afford larger NC, hence allow fewer small options.
if (M_ <= 2 * mr) reps -= 1; if (max_M_ <= 2 * mr) reps -= 1;
size_t prev = nc_max; size_t prev = nc_max;
for (size_t rep = 0; rep < reps; ++rep) { for (size_t rep = 0; rep < reps; ++rep) {
@ -302,14 +347,7 @@ class GenerateCandidates {
all_nc.begin() + HWY_MIN(want_delete, max_delete)); all_nc.begin() + HWY_MIN(want_delete, max_delete));
} }
if (print_config_ && all_nc.size() > 1) { MaybePrintSizes(N_, nc_max, "NC", all_nc);
fprintf(stderr, "num_B %zu: NC: ", num_B_);
for (size_t nc : all_nc) {
fprintf(stderr, "%zu ", nc);
}
fprintf(stderr, "\n");
}
return all_nc; return all_nc;
} }
@ -319,8 +357,8 @@ class GenerateCandidates {
std::vector<int> inner_tasks; std::vector<int> inner_tasks;
inner_tasks.reserve(3); inner_tasks.reserve(3);
inner_tasks.push_back(1); inner_tasks.push_back(1);
// Blocks have one task per mc/nc range and ignore this parameter. // Multiple-MC have one task per mc/nc range and ignore this parameter.
if (!IsBlock(order)) { if (IsOneMC(order)) {
inner_tasks.push_back(2); inner_tasks.push_back(2);
inner_tasks.push_back(4); inner_tasks.push_back(4);
} }
@ -328,7 +366,7 @@ class GenerateCandidates {
} }
const CacheInfo& cache_; const CacheInfo& cache_;
const size_t M_; const size_t max_M_;
const size_t K_; const size_t K_;
const size_t N_; const size_t N_;
const size_t num_B_; const size_t num_B_;
@ -343,10 +381,11 @@ class GenerateCandidates {
} // namespace } // namespace
// Facade to avoid exposing `GenerateCandidates` in the header. // Facade to avoid exposing `GenerateCandidates` in the header.
std::vector<MMConfig> MMCandidates(const CacheInfo& cache, size_t M, size_t K, std::vector<MMConfig> MMCandidates(const CacheInfo& cache, size_t max_M,
size_t N, size_t num_B, size_t sizeof_TC, size_t K, size_t N, size_t num_B,
bool print_config) { size_t sizeof_TC, bool print_config) {
return GenerateCandidates(cache, M, K, N, num_B, sizeof_TC, print_config)(); return GenerateCandidates(cache, max_M, K, N, num_B, sizeof_TC,
print_config)();
} }
MatMulEnv::MatMulEnv(ThreadingContext& ctx) MatMulEnv::MatMulEnv(ThreadingContext& ctx)


@ -61,7 +61,7 @@ HWY_INLINE_VAR constexpr size_t kMaxNC = 16384;
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`. // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024; HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024;
// Policy classes for parallelism, implementing some of `ParallelismStrategy`. // Policy classes for parallelism, implementing some of `Parallelism`.
struct MMParallelNone { struct MMParallelNone {
template <class Func> template <class Func>
@ -103,18 +103,14 @@ struct MMParallelWithinCluster {
template <class Func> template <class Func>
void ForN(ThreadingContext& ctx, const IndexRange& range_n, size_t n_multiple, void ForN(ThreadingContext& ctx, const IndexRange& range_n, size_t n_multiple,
size_t inner_tasks, size_t cluster_idx, const Func& func) const { size_t inner_tasks, size_t cluster_idx, const Func& func) const {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); const hwy::pool::Caller caller =
ctx.pool_callers.Get(Callers::kMMClusterForN);
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); ParallelPartitionWithinCluster(
const size_t base = ctx.Worker(cluster_idx); range_n, n_multiple, inner_tasks, ctx, cluster_idx, caller,
[&](const IndexRange& worker_range, size_t worker) {
const IndexRangePartition ranges_n = StaticPartition( func(worker_range, worker);
range_n, cluster.NumWorkers() * inner_tasks, n_multiple); });
ParallelizeOneRange(ranges_n, cluster,
ctx.pool_callers.Get(Callers::kMMClusterForN),
[&](const IndexRange& worker_range, size_t worker) {
func(worker_range, base + worker);
});
} }
template <class Func> template <class Func>
@ -122,79 +118,56 @@ struct MMParallelWithinCluster {
const IndexRangePartition& ranges_mc, const IndexRangePartition& ranges_mc,
const IndexRangePartition& ranges_nc, size_t cluster_idx, const IndexRangePartition& ranges_nc, size_t cluster_idx,
const Func& func) const { const Func& func) const {
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); const hwy::pool::Caller caller =
const size_t base = ctx.Worker(cluster_idx); ctx.pool_callers.Get(Callers::kMMClusterForMCNC);
// Low-batch: avoid Divide/Remainder. // We are running on one pool, hence collapse into a 1D range.
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) { const hwy::Divisor div_m(static_cast<uint32_t>(ranges_mc.NumTasks()));
ParallelizeOneRange(ranges_nc, cluster, const auto get_mc = [&](uint64_t task) {
ctx.pool_callers.Get(Callers::kMMClusterForMCNC), return ranges_mc.Range(div_m.Remainder(static_cast<uint32_t>(task)));
[&](const IndexRange& range_nc, size_t worker) { };
func(ranges_mc.Range(0), range_nc, base + worker); const auto get_nc = [&](uint64_t task) {
}); return ranges_nc.Range(div_m.Divide(static_cast<uint32_t>(task)));
} else { };
ParallelizeTwoRanges( const size_t num_tasks = ranges_mc.NumTasks() * ranges_nc.NumTasks();
ranges_mc, ranges_nc, cluster,
ctx.pool_callers.Get(Callers::kMMClusterForMCNC), ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller,
[&](const IndexRange& range_mc, const IndexRange& range_nc, [&](uint64_t task, size_t worker) {
size_t worker) { func(range_mc, range_nc, base + worker); }); func(get_mc(task), get_nc(task), worker);
} });
} }
template <class Func> template <class Func>
void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
size_t cluster_idx, const Func& func) const { size_t cluster_idx, const Func& func) const {
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); const hwy::pool::Caller caller =
const size_t base = ctx.Worker(cluster_idx); ctx.pool_callers.Get(Callers::kMMClusterForMC);
cluster.Run( ParallelForWithinCluster(
range_mc.begin(), range_mc.end(), range_mc.Num(), ctx, cluster_idx, caller,
ctx.pool_callers.Get(Callers::kMMClusterForMC), [&](uint64_t i, size_t worker) { func(range_mc.begin() + i, worker); });
[&](uint64_t row_a, size_t worker) { func(row_a, base + worker); });
} }
}; };
struct MMParallelHierarchical { struct MMParallelHierarchical {
// Cluster/CCX-aware parallel-for over B rows in `range_n`. `n_multiple` is // Similar to `HierarchicalParallelFor`, but over *sub-ranges* of B rows in
// the granularity of per-cluster tasks. Calls `func(worker_range, worker)`. // `range_n` governed by `n_multiple` and `inner_tasks`.
template <class Func> template <class Func>
void ForN(ThreadingContext& ctx, const IndexRange& range_n, size_t n_multiple, void ForN(ThreadingContext& ctx, const IndexRange& range_n, size_t n_multiple,
size_t inner_tasks, HWY_MAYBE_UNUSED size_t caller_cluster_idx, size_t inner_tasks, size_t caller_cluster_idx,
const Func& func) const { const Func& func) const {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
HWY_DASSERT(caller_cluster_idx == 0); HWY_DASSERT(caller_cluster_idx == 0);
(void)caller_cluster_idx;
const hwy::pool::Caller caller = ctx.pool_callers.Get(Callers::kMMHierForN); const hwy::pool::Caller caller = ctx.pool_callers.Get(Callers::kMMHierForN);
// Single cluster: parallel-for over static partition of `range_n`. // Assign clusters (if any) a sub-range of `range_n` (typically hundreds).
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(); ParallelPartitionAcrossClusters(
const size_t num_clusters = all_clusters.NumWorkers(); range_n, n_multiple, /*inner_tasks=*/1, ctx, caller,
if (num_clusters == 1) { [&](const IndexRange& cluster_range, size_t cluster_idx) {
const size_t cluster_idx = 0; ParallelPartitionWithinCluster(
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); cluster_range, n_multiple, inner_tasks, ctx, cluster_idx, caller,
const IndexRangePartition ranges_n = StaticPartition(
range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
return ParallelizeOneRange(
ranges_n, cluster, caller,
[&](const IndexRange& worker_range, size_t worker) {
func(worker_range, worker);
});
}
// Assign each cluster a sub-range of `range_n` (typically hundreds).
const IndexRangePartition ranges_n =
StaticPartition(range_n, num_clusters, n_multiple);
ParallelizeOneRange(
ranges_n, all_clusters, caller,
[&](const IndexRange& n_range, const size_t cluster_idx) {
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
const size_t cluster_base = ctx.Worker(cluster_idx);
// Parallel-for over sub-ranges of `cluster_range` within the cluster.
const IndexRangePartition worker_ranges = StaticPartition(
n_range, cluster.NumWorkers() * inner_tasks, n_multiple);
ParallelizeOneRange(
worker_ranges, cluster, caller,
[&](const IndexRange& worker_range, size_t worker) { [&](const IndexRange& worker_range, size_t worker) {
func(worker_range, cluster_base + worker); func(worker_range, worker);
}); });
}); });
} }
@ -205,69 +178,56 @@ struct MMParallelHierarchical {
void ForRangesMC_NC(ThreadingContext& ctx, void ForRangesMC_NC(ThreadingContext& ctx,
const IndexRangePartition& ranges_mc, const IndexRangePartition& ranges_mc,
const IndexRangePartition& ranges_nc, const IndexRangePartition& ranges_nc,
HWY_MAYBE_UNUSED size_t caller_cluster_idx, size_t caller_cluster_idx, const Func& func) const {
const Func& func) const {
HWY_DASSERT(caller_cluster_idx == 0); HWY_DASSERT(caller_cluster_idx == 0);
(void)caller_cluster_idx;
const hwy::pool::Caller caller = const hwy::pool::Caller caller =
ctx.pool_callers.Get(Callers::kMMHierForMCNC); ctx.pool_callers.Get(Callers::kMMHierForMCNC);
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(); // Collapse two range indices into a 1D range for better load-balancing,
// `all_clusters` is a pool with one worker per cluster in a package. // because `ranges_mc` may just have one task.
const size_t num_clusters = all_clusters.NumWorkers(); const hwy::Divisor div_m(static_cast<uint32_t>(ranges_mc.NumTasks()));
// Single (big) cluster: collapse two range indices into one parallel-for const auto get_mc = [&](uint64_t task) {
// to reduce the number of fork-joins. return ranges_mc.Range(div_m.Remainder(static_cast<uint32_t>(task)));
if (num_clusters == 1) { };
const size_t cluster_idx = 0; const auto get_nc = [&](uint64_t task) {
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); return ranges_nc.Range(div_m.Divide(static_cast<uint32_t>(task)));
// Low-batch: avoid Divide/Remainder. };
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) { const IndexRange all_range(0, ranges_mc.NumTasks() * ranges_nc.NumTasks());
return ParallelizeOneRange(
ranges_nc, cluster, caller,
[&](const IndexRange& range_nc, size_t worker) {
func(ranges_mc.Range(0), range_nc, worker);
});
} else {
return ParallelizeTwoRanges(
ranges_mc, ranges_nc, cluster, caller,
[&](const IndexRange& range_mc, const IndexRange& range_nc,
size_t worker) { func(range_mc, range_nc, worker); });
}
}
// Multiple clusters: N across clusters (both are usually the larger), and ParallelPartitionAcrossClusters(
// M within each cluster. We assume auto-tuning finds small MC/NC tasks. all_range, /*task_multiple=*/1, /*inner_tasks=*/1, ctx, caller,
ParallelizeOneRange( [&](const IndexRange& cluster_range, size_t cluster_idx) {
ranges_nc, all_clusters, caller, ParallelForWithinCluster(cluster_range.Num(), ctx, cluster_idx,
[&](const IndexRange range_nc, size_t cluster_idx) { caller, [&](uint64_t i, size_t worker) {
const size_t cluster_base = ctx.Worker(cluster_idx); const size_t task =
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); cluster_range.begin() + i;
ParallelizeOneRange(ranges_mc, cluster, caller, func(get_mc(task), get_nc(task), worker);
[&](const IndexRange& range_mc, size_t worker) { });
func(range_mc, range_nc, cluster_base + worker);
});
}); });
} }
// Calls `func(row_a, worker)` in parallel. // No multiple/inner_tasks, so this is just HierarchicalParallelFor.
template <class Func> template <class Func>
void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
size_t caller_cluster_idx, const Func& func) const { size_t caller_cluster_idx, const Func& func) const {
HierarchicalParallelFor(range_mc.Num(), ctx, Callers::kMMHierForMC, HWY_DASSERT(caller_cluster_idx == 0);
[&](size_t task, size_t worker) { (void)caller_cluster_idx;
func(range_mc.begin() + task, worker); HierarchicalParallelFor(
}); range_mc.Num(), ctx, Callers::kMMHierForMC,
[&](size_t i, size_t worker) { func(range_mc.begin() + i, worker); });
} }
}; };
template <class Func, typename... Args> template <class Func, typename... Args>
void DispatchParallelism(ParallelismStrategy parallelism, const Func& func, void DispatchParallelism(Parallelism parallelism, const Func& func,
Args&&... args) { Args&&... args) {
switch (parallelism) { switch (parallelism) {
case ParallelismStrategy::kNone: case Parallelism::kNone:
return func(MMParallelNone(), std::forward<Args>(args)...); return func(MMParallelNone(), std::forward<Args>(args)...);
case ParallelismStrategy::kWithinCluster: case Parallelism::kWithinCluster:
return func(MMParallelWithinCluster(), std::forward<Args>(args)...); return func(MMParallelWithinCluster(), std::forward<Args>(args)...);
case ParallelismStrategy::kHierarchical: case Parallelism::kHierarchical:
return func(MMParallelHierarchical(), std::forward<Args>(args)...); return func(MMParallelHierarchical(), std::forward<Args>(args)...);
default: default:
HWY_UNREACHABLE; HWY_UNREACHABLE;
@ -371,8 +331,8 @@ void DispatchOrder(MMOrder order, const Func& func, Args&&... args) {
} }
} }
static inline bool IsBlock(MMOrder order) { static inline bool IsOneMC(MMOrder order) {
return order == MMOrder::kNT_MT_K || order == MMOrder::kNT_MT; return order == MMOrder::kNT || order == MMOrder::kNT_K;
} }
static inline bool IsOneKC(MMOrder order) { static inline bool IsOneKC(MMOrder order) {
@ -421,6 +381,8 @@ static inline const char* StringFromParA(MMParA par_a) {
// `mc` := A rows such that `kc` columns fit in L2, // `mc` := A rows such that `kc` columns fit in L2,
// `nc` := B rows such that `kc` columns fit in L3 alongside `mc x nc` C. // `nc` := B rows such that `kc` columns fit in L3 alongside `mc x nc` C.
// Also includes loop order and task granularity. // Also includes loop order and task granularity.
//
// This is shared by multiple M which return the same `BucketM`.
#pragma pack(push, 1) #pragma pack(push, 1)
class MMConfig { class MMConfig {
public: public:
@ -428,8 +390,8 @@ class MMConfig {
// `mr` is the number of A rows per call to `MMKernel::LoopKC`. // `mr` is the number of A rows per call to `MMKernel::LoopKC`.
// `MMOrder` is how to parallelize the outer loops. // `MMOrder` is how to parallelize the outer loops.
// `inner_tasks` chooses the within-cluster task granularity in `ForN`. // `inner_tasks` chooses the within-cluster task granularity in `ForN`.
MMConfig(size_t K, size_t N, size_t mr, size_t mc, size_t kc, size_t nc, MMConfig(size_t M, size_t K, size_t N, size_t mr, size_t mc, size_t kc,
size_t kc_multiple, size_t nc_multiple, MMOrder order, size_t nc, size_t kc_multiple, size_t nc_multiple, MMOrder order,
int inner_tasks) int inner_tasks)
: mr_(static_cast<uint32_t>(mr)), : mr_(static_cast<uint32_t>(mr)),
mc_(static_cast<uint32_t>(mc)), mc_(static_cast<uint32_t>(mc)),
@ -441,11 +403,7 @@ class MMConfig {
inner_tasks_(static_cast<uint8_t>(inner_tasks)), inner_tasks_(static_cast<uint8_t>(inner_tasks)),
reserved_{} { reserved_{} {
HWY_DASSERT(mr == 1 || mr == 2 || mr == 4); HWY_DASSERT(mr == 1 || mr == 2 || mr == 4);
if (mc % mr != 0) { // Some models have K which are not multiples of `kc_multiple`.
HWY_WARN("mc %zu not a multiple of mr %zu", mc, mr);
}
// Do not warn for single-kc tasks; some models unfortunately have K which
// are not multiples of `kc_multiple`.
if (kc != K && (kc % kc_multiple) != 0) { if (kc != K && (kc % kc_multiple) != 0) {
HWY_WARN("kc %zu not a multiple of kc_multiple %zu", kc, kc_multiple); HWY_WARN("kc %zu not a multiple of kc_multiple %zu", kc, kc_multiple);
} }
@ -457,11 +415,21 @@ class MMConfig {
} }
// Splits M/N into blocks which are visited sequentially or in parallel. // Splits M/N into blocks which are visited sequentially or in parallel.
// K is always sequential, see `MMOrder`.
IndexRangePartition RangesOfMC(size_t M) const { IndexRangePartition RangesOfMC(size_t M) const {
return MaxSizePartition(IndexRange(0, M), mc_, mr_); if (IsOneMC(order_)) {
// Must have exactly one M range/tile, regardless of `mr_` and `mc_`.
return IndexRangePartition(M);
}
const size_t mc = HWY_MIN(M, MC());
const size_t mr = HWY_MIN(M, MR());
return MaxSizePartition(IndexRange(0, M), mc, mr);
} }
// K is either a single range, or a sequential loop.
IndexRangePartition RangesOfKC(size_t K) const { IndexRangePartition RangesOfKC(size_t K) const {
if (IsOneKC(order_)) {
// Must have exactly one K range/tile, regardless of `kc_`.
return IndexRangePartition(K);
}
return MaxSizePartition(IndexRange(0, K), kc_, kc_multiple_); return MaxSizePartition(IndexRange(0, K), kc_, kc_multiple_);
} }
IndexRangePartition RangesOfNC(size_t N) const { IndexRangePartition RangesOfNC(size_t N) const {
@ -488,7 +456,7 @@ class MMConfig {
uint32_t kc_multiple_; uint32_t kc_multiple_;
MMOrder order_; MMOrder order_;
uint8_t inner_tasks_; uint8_t inner_tasks_;
HWY_MAYBE_UNUSED uint8_t reserved_[6]; HWY_MEMBER_VAR_MAYBE_UNUSED uint8_t reserved_[6];
}; };
static_assert(sizeof(MMConfig) == 32); // for faster indexing static_assert(sizeof(MMConfig) == 32); // for faster indexing
#pragma pack(pop) #pragma pack(pop)
@ -597,26 +565,27 @@ class MMAutoTune {
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
// Minimum M, in units of tile rows of height mr={1, 2, 4}, from which
// `MMOrder::kNT[_K]` are no longer allowed. They require a single MC range,
// but choosing the same config for a larger M can result in multiple MC ranges.
// Thus M less than this must have unique keys/configs.
HWY_INLINE_VAR constexpr size_t kMaxTilesM = 8;
// Map of previously seen dimensions to index via linear search. // Map of previously seen dimensions to index via linear search.
class MMKeys { class MMKeys {
// Group batch size into buckets to reduce #auto-tunes.
static size_t BucketM(size_t M) {
if (M < kMaxTilesM * kMaxMR) return M; // See kMaxTilesM above.
if (M <= 128) return 128;
return 512;
}
public: public:
using Key = uint64_t; using Key = uint64_t;
// KeyFromDims will only return this if all dims are zero, which is invalid. // KeyFromDims will only return this if all dims are zero, which is invalid.
static constexpr Key kPadding = 0; static constexpr Key kPadding = 0;
// Returns the maximum permissible M in the bucket, for grouping batch sizes
// into buckets to reduce #auto-tunes.
static size_t BucketM(size_t M) {
HWY_DASSERT(M != 0);
// Small M: 1..3, 4..7, 8..11, etc. share the same config.
if (M < 64) return M | (kMaxMR - 1);
// Larger M use power of two buckets: 64..127, 128..255, etc.
const size_t floor_log2_M =
31 - hwy::Num0BitsAboveMS1Bit_Nonzero32(static_cast<uint32_t>(M));
const size_t min_M = size_t{1} << floor_log2_M;
HWY_DASSERT(min_M <= M && M < 2 * min_M);
return 2 * min_M - 1;
}
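For intuition, here is a minimal, self-contained sketch of the same bucketing (illustrative only: it assumes kMaxMR == 4, as implied by the mr == 1/2/4 assertion above, and uses a loop instead of Num0BitsAboveMS1Bit_Nonzero32):
size_t BucketM_sketch(size_t M) {      // hypothetical helper, not part of the change
  if (M < 64) return M | (4 - 1);      // 1..3 -> 3, 4..7 -> 7, 8..11 -> 11
  size_t min_M = 1;
  while (min_M * 2 <= M) min_M *= 2;   // floor power of two
  return 2 * min_M - 1;                // 64..127 -> 127, 128..255 -> 255
}
// e.g. BucketM_sketch(5) == 7, BucketM_sketch(100) == 127, BucketM_sketch(300) == 511.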
// Compresses the dimensions into a single Key for faster comparison. // Compresses the dimensions into a single Key for faster comparison.
static Key KeyFromDims(size_t M, size_t K, size_t N, size_t num_B) { static Key KeyFromDims(size_t M, size_t K, size_t N, size_t num_B) {
HWY_DASSERT(M < (Key{1} << 16)); // batch sizes are smaller HWY_DASSERT(M < (Key{1} << 16)); // batch sizes are smaller
@ -747,7 +716,7 @@ class MMOptions {
const void* opaque = nullptr; const void* opaque = nullptr;
uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`. uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`.
ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical; Parallelism parallelism = Parallelism::kHierarchical;
}; };
// Arguments to MatMul() that are independent of the A/B/C types. Reduces // Arguments to MatMul() that are independent of the A/B/C types. Reduces

View File

@ -195,9 +195,10 @@ HWY_INLINE void MatMulSlow(const MatPtrT<TA> A, const MatPtrT<TB> B,
const size_t multiple = env.ctx.allocator.QuantumBytes() / sizeof(TB); const size_t multiple = env.ctx.allocator.QuantumBytes() / sizeof(TB);
const IndexRangePartition get_col_c = const IndexRangePartition get_col_c =
StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple); StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple);
ParallelizeOneRange( ParallelForAcrossClusters(
get_col_c, all_clusters, env.ctx.pool_callers.Get(Callers::kTest), get_col_c.NumTasks(), env.ctx, env.ctx.pool_callers.Get(Callers::kTest),
[&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR { [&](size_t range_idx, size_t cluster_idx) HWY_ATTR {
const IndexRange cols_c = get_col_c.Range(range_idx);
for (size_t r : all_rows_c) { for (size_t r : all_rows_c) {
TC* HWY_RESTRICT C_row = C.Row(r); TC* HWY_RESTRICT C_row = C.Row(r);
for (size_t c : cols_c) { for (size_t c : cols_c) {

View File

@ -25,9 +25,11 @@
#include <cstdint> #include <cstdint>
#include <random> #include <random>
#include <type_traits> // std::enable_if_t #include <type_traits> // std::enable_if_t
#include <utility>
#include <vector> #include <vector>
#include "ops/matmul.h" #include "ops/matmul.h"
#include "ops/ops.h"
#include "util/allocator.h" #include "util/allocator.h"
#include "util/basics.h" // TokenAndProb, RngStream #include "util/basics.h" // TokenAndProb, RngStream
#include "util/mat.h" #include "util/mat.h"
@ -61,6 +63,9 @@ namespace gcpp {
namespace HWY_NAMESPACE { namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
// Computes C = A * B + add via MatMulStatic.
// This function uses CallUpcasted to dispatch to the correct MatMulStatic
// instantiation based on the runtime type of B.
template <typename TA, typename TC> template <typename TA, typename TC>
MMPerKey* CallMatMul(const MatPtrT<TA>& A, const MatPtr& B, MMPerKey* CallMatMul(const MatPtrT<TA>& A, const MatPtr& B,
const float* HWY_RESTRICT add, MatMulEnv& env, const float* HWY_RESTRICT add, MatMulEnv& env,
@ -497,10 +502,10 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
size_t cluster_idx = 0) { size_t cluster_idx = 0) {
HWY_DASSERT(weights.Rows() == 1); HWY_DASSERT(weights.Rows() == 1);
HWY_DASSERT(weights.Cols() == activations.Cols()); HWY_DASSERT(weights.Cols() == activations.Cols());
HWY_DASSERT(activations.SameShape(out)); activations.DebugCheckSameShape(out);
CallUpcasted(&weights, [&](const auto* weights_t) { CallUpcasted(&weights, [&](const auto* weights_t) {
ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx, ParallelFor(Parallelism::kFlat, activations.Rows(), ctx,
cluster_idx, Callers::kOpsRMSNormBatched, cluster_idx, Callers::kOpsRMSNormBatched,
[&](uint64_t token_idx, size_t worker) { [&](uint64_t token_idx, size_t worker) {
RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(),
@ -517,7 +522,7 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
HWY_DASSERT(weights.Cols() == inout.Cols()); HWY_DASSERT(weights.Cols() == inout.Cols());
CallUpcasted(&weights, [&](const auto* weights_t) { CallUpcasted(&weights, [&](const auto* weights_t) {
ParallelFor(ParallelismStrategy::kFlat, inout.Rows(), ctx, cluster_idx, ParallelFor(Parallelism::kFlat, inout.Rows(), ctx, cluster_idx,
Callers::kOpsRMSNormInplaceBatched, Callers::kOpsRMSNormInplaceBatched,
[&](uint64_t token_idx, size_t worker) { [&](uint64_t token_idx, size_t worker) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0,
@ -550,7 +555,7 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
size_t cluster_idx = 0) { size_t cluster_idx = 0) {
HWY_DASSERT(out.SameShape(x)); HWY_DASSERT(out.SameShape(x));
ParallelFor( ParallelFor(
ParallelismStrategy::kFlat, out.Rows(), ctx, cluster_idx, Parallelism::kFlat, out.Rows(), ctx, cluster_idx,
Callers::kOpsAddFromBatched, [&](uint64_t token_idx, size_t worker) { Callers::kOpsAddFromBatched, [&](uint64_t token_idx, size_t worker) {
AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), ctx, worker); AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), ctx, worker);
}); });
@ -1122,9 +1127,25 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
// See below for a specialized version for top-1 sampling. // See below for a specialized version for top-1 sampling.
// TODO: support bf16 logits using Decompress2. // TODO: support bf16 logits using Decompress2.
// Computes softmax probabilities for the given logits, normalizing in-place.
// The calculation is numerically stable: it uses the max-subtraction trick to
// compute exp(logits[i] - max(logits)) before normalizing by the sum.
// If temperature is not 1.0, each intermediate exp() result is divided by
// temperature before normalization. Because every term is divided by the same
// constant, the division cancels out in the final normalization, so
// temperature currently has no effect on the output probabilities.
// @param logits In-out: on input, contains logits; on output, overwritten with
// probabilities.
// @param ctx Input: threading context for parallelism and profiling.
// @param worker Input: worker thread index.
// @param temperature Input: softmax temperature (currently a no-op, see above).
// @param sm_options Optional outputs: if `max_out` is non-null, it receives
// the max logit value, and `d_out` (which must then also be non-null)
// receives the sum of exp(logit - max).
static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx, static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
const size_t worker, const size_t worker, float temperature = 1.0f,
float temperature = 1.0f) { const SMOptions& sm_options = {}) {
GCPP_ZONE(ctx, worker, Zones::kOpsSoftmax); GCPP_ZONE(ctx, worker, Zones::kOpsSoftmax);
HWY_DASSERT(logits.size() != 0); HWY_DASSERT(logits.size() != 0);
@ -1168,6 +1189,10 @@ static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
// Double-precision reciprocal does not appear to affect the results. // Double-precision reciprocal does not appear to affect the results.
const float mul = 1.0f / sum_exp; const float mul = 1.0f / sum_exp;
MulByConst(mul, logits.data(), logits.size()); MulByConst(mul, logits.data(), logits.size());
if (sm_options.max_out) {
*sm_options.max_out = hn::GetLane(vmax);
*sm_options.d_out = sum_exp;
}
} }
// Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max / // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
@ -1290,7 +1315,7 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
const float cap, MatPtrT<float>& x, const hwy::BitSet4096<>& non_eos, const float cap, MatPtrT<float>& x, const hwy::BitSet4096<>& non_eos,
ThreadingContext& ctx, size_t cluster_idx = 0) { ThreadingContext& ctx, size_t cluster_idx = 0) {
if (cap == 0.0f) return; if (cap == 0.0f) return;
ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx, ParallelFor(Parallelism::kFlat, x.Rows(), ctx, cluster_idx,
Callers::kOpsMaybeLogitsSoftCapBatched, Callers::kOpsMaybeLogitsSoftCapBatched,
[&](uint64_t task, size_t worker) { [&](uint64_t task, size_t worker) {
if (non_eos.Get(task)) { if (non_eos.Get(task)) {

View File

@ -41,6 +41,11 @@ static inline HWY_MAYBE_UNUSED MatStorageT<float> CreateInvTimescale(
return inv_timescale; return inv_timescale;
} }
struct SMOptions {
float* HWY_RESTRICT max_out = nullptr;
float* HWY_RESTRICT d_out = nullptr;
};
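A minimal usage sketch of the optional outputs added above (`x` and `count` are placeholders; this mirrors the new TestSoftmaxState test below):
float max_logit, sum_exp;
Softmax(Logits(x, count), ctx, /*worker=*/0, /*temperature=*/1.0f,
        {.max_out = &max_logit, .d_out = &sum_exp});
// max_logit == max(logits); sum_exp == sum of exp(logit - max_logit).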
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_OPS_OPS_H_ #endif // THIRD_PARTY_GEMMA_CPP_OPS_OPS_H_

View File

@ -14,7 +14,6 @@
// limitations under the License. // limitations under the License.
#include "compression/types.h" #include "compression/types.h"
#include "util/zones.h"
#ifndef HWY_DISABLED_TARGETS #ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS
@ -38,7 +37,6 @@
#include "util/mat.h" // MatStorageT #include "util/mat.h" // MatStorageT
#include "util/test_util.h" #include "util/test_util.h"
#include "util/threading_context.h" #include "util/threading_context.h"
#include "hwy/profiler.h"
#include "hwy/tests/hwy_gtest.h" #include "hwy/tests/hwy_gtest.h"
// clang-format off // clang-format off
@ -348,6 +346,51 @@ void TestAllSoftmax() {
hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmax>>()(float()); hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmax>>()(float());
} }
class TestSoftmaxState {
public:
template <class D>
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
hwy::RandomState& rng) {
if (count == 0) return; // *Softmax would assert
if (misalign_b == 0) return;
using T = hn::TFromD<D>;
hwy::AlignedFreeUniquePtr<T[]> px =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
hwy::AlignedFreeUniquePtr<T[]> pe =
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
HWY_ASSERT(px && pe);
T* x = px.get() + misalign_a;
T* initial_logits = pe.get() + misalign_a;
for (size_t i = 0; i < count; ++i) {
x[i] = Random<T>(rng);
initial_logits[i] = x[i];
}
float softmax_max;
float softmax_d;
Softmax(Logits(x, count), Ctx(), /*worker=*/0, /*temperature=*/1.0f,
{.max_out = &softmax_max, .d_out = &softmax_d});
const float maxval =
*std::max_element(initial_logits, initial_logits + count);
float sum_exp = 0.0f;
for (size_t i = 0; i < count; ++i) {
sum_exp += std::exp(initial_logits[i] - maxval);
}
ASSERT_NEAR(softmax_max, maxval, 1e-6);
ASSERT_NEAR(softmax_d, sum_exp, 1e-6);
}
};
void TestAllSoftmaxState() {
hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmaxState>>()(float());
}
template <size_t k> template <size_t k>
struct TestCreateDistribution { struct TestCreateDistribution {
void operator()(hwy::RandomState& rng) { void operator()(hwy::RandomState& rng) {
@ -456,7 +499,7 @@ void TestRopeAndMulBy() {
x.Row(0)[i] = random_float(); x.Row(0)[i] = random_float();
} }
const float qmul = AttentionActivations::ChooseQueryScale(config); const float qmul = ChooseQueryScale(config);
constexpr float kmul = 1.0f; constexpr float kmul = 1.0f;
MatStorageT<float> qexpected("qexpected", dim_qkv, ctx.allocator); MatStorageT<float> qexpected("qexpected", dim_qkv, ctx.allocator);
@ -771,6 +814,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstTo); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstTo);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmaxState);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSigmoid); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSigmoid);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllGelu); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllGelu);

View File

@ -29,6 +29,7 @@ cc_test(
deps = [ deps = [
":image", ":image",
"@googletest//:gtest_main", # buildcleaner: keep "@googletest//:gtest_main", # buildcleaner: keep
"@highway//:hwy",
], ],
) )

View File

@ -37,8 +37,6 @@
namespace gcpp { namespace gcpp {
namespace { namespace {
// Hardcoded for PaliGemma ViT input.
constexpr size_t kPatchSize = 14;
// Returns the linearly scaled index in [0, to_size) closest to the // Returns the linearly scaled index in [0, to_size) closest to the
// value in [0, from_size). // value in [0, from_size).
@ -208,24 +206,25 @@ bool Image::WriteBinary(const std::string& filename) const {
} }
// Image.data() is H x W x 3. // Image.data() is H x W x 3.
// We want the N-th patch of size kPatchSize x kPatchSize x 3. // We want the N-th patch of size patch_dim x patch_dim x 3.
void Image::GetPatch(size_t patch_num, float* patch) const { void Image::GetPatch(size_t patch_num, const hwy::Divisor& div_patch_dim,
float* patch) const {
PROFILER_FUNC; PROFILER_FUNC;
constexpr size_t kNumChannels = 3; constexpr size_t kNumChannels = 3;
constexpr size_t kBytesPerPixel = (kNumChannels * sizeof(float)); constexpr size_t kBytesPerPixel = kNumChannels * sizeof(float);
constexpr size_t kBytesPerRow = (kPatchSize * kBytesPerPixel); const size_t patch_dim = div_patch_dim.GetDivisor();
const size_t kDataSize = width_ * height_ * kNumChannels; const size_t bytes_per_row = (patch_dim * kBytesPerPixel);
const size_t in_bytes_to_next_row = (width_ * kBytesPerPixel); const size_t in_bytes_to_next_row = (width_ * kBytesPerPixel);
HWY_ASSERT(size() == kDataSize); HWY_ASSERT(size() == width_ * height_ * kNumChannels);
HWY_ASSERT(width_ % kPatchSize == 0); HWY_ASSERT(div_patch_dim.Remainder(width_) == 0);
HWY_ASSERT(height_ % kPatchSize == 0); HWY_ASSERT(div_patch_dim.Remainder(height_) == 0);
const size_t kNumPatchesPerRow = width_ / kPatchSize; const size_t patches_x = div_patch_dim.Divide(width_);
size_t patch_y = patch_num / kNumPatchesPerRow; size_t patch_y = patch_num / patches_x;
size_t patch_x = patch_num % kNumPatchesPerRow; size_t patch_x = patch_num % patches_x;
HWY_ASSERT(0 <= patch_y && patch_y < height_ / kPatchSize); HWY_DASSERT(0 <= patch_y && patch_y < div_patch_dim.Divide(height_));
HWY_ASSERT(0 <= patch_x && patch_x < kNumPatchesPerRow); HWY_DASSERT(0 <= patch_x && patch_x < patches_x);
patch_y *= kPatchSize; patch_y *= patch_dim;
patch_x *= kPatchSize; patch_x *= patch_dim;
// Move `out` and `in` to the start of the patch. // Move `out` and `in` to the start of the patch.
char* out = reinterpret_cast<char*>(patch); char* out = reinterpret_cast<char*>(patch);
@ -233,9 +232,9 @@ void Image::GetPatch(size_t patch_num, float* patch) const {
in += (((patch_y * width_) + patch_x) * kBytesPerPixel); in += (((patch_y * width_) + patch_x) * kBytesPerPixel);
// Copy the patch one row at a time. // Copy the patch one row at a time.
for (size_t y = 0; y < kPatchSize; ++y) { for (size_t y = 0; y < patch_dim; ++y) {
std::memcpy(out, in, kBytesPerRow); std::memcpy(out, in, bytes_per_row);
out += kBytesPerRow; out += bytes_per_row;
in += in_bytes_to_next_row; in += in_bytes_to_next_row;
} }
} }
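For intuition, the raster-order index math for the 224 x 224 case exercised in the tests below (patch_dim = 14, hence 16 patches per row):
const size_t patches_x = 224 / 14;      // 16
const size_t patch_y = 18 / patches_x;  // 1
const size_t patch_x = 18 % patches_x;  // 2
// Patch 18 therefore starts at pixel (row, col) = (14, 28), matching the
// image_test expectation data()[(14 * 224 + 2 * 14) * 3 + i].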

View File

@ -21,6 +21,7 @@
#include <vector> #include <vector>
#include "hwy/aligned_allocator.h" // Span #include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h" // Divisor
namespace gcpp { namespace gcpp {
@ -44,11 +45,12 @@ class Image {
bool WriteBinary(const std::string& filename) const; bool WriteBinary(const std::string& filename) const;
// Stores the patch for the given patch number in `patch`. // Stores the patch for the given patch number in `patch`.
// Patches are numbered in usual raster-order. E.g. for an image of size // Patches are numbered in usual raster-order. E.g. for an image of size
// 224 x 224, there are 16 x 16 = 256 patches. // 224 x 224 and patch_dim = 14, there are 16 x 16 = 256 patches.
// `patch` should have space for at least 14 * 14 * 3 = 588 floats. // `patch` should have space for at least patch_dim * patch_dim * 3.
// Requires that Normalize() has been called and that the image width and // Requires that Normalize() has been called and that the image width and
// height are multiples of 14. // height are multiples of patch_dim.
void GetPatch(size_t patch_num, float* patch) const; void GetPatch(size_t patch_num, const hwy::Divisor& div_patch_dim,
float* patch) const;
float *data() { return data_.data(); } float *data() { return data_.data(); }
const float *data() const { return data_.data(); } const float *data() const { return data_.data(); }

View File

@ -20,6 +20,7 @@
#include <vector> #include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "hwy/base.h"
namespace gcpp { namespace gcpp {
namespace { namespace {
@ -61,11 +62,12 @@ TEST(ImageTest, LoadResize224GetPatch) {
EXPECT_EQ(image.data()[image.size() - 1], Normalize(122)); EXPECT_EQ(image.data()[image.size() - 1], Normalize(122));
// Extract two patches. // Extract two patches.
float patch[588]; float patch[588];
image.GetPatch(0, patch); const hwy::Divisor div_patch_dim(14);
image.GetPatch(0, div_patch_dim, patch);
EXPECT_EQ(patch[0], Normalize(160)); EXPECT_EQ(patch[0], Normalize(160));
EXPECT_EQ(patch[1], Normalize(184)); EXPECT_EQ(patch[1], Normalize(184));
EXPECT_EQ(patch[2], Normalize(188)); EXPECT_EQ(patch[2], Normalize(188));
image.GetPatch(18, patch); image.GetPatch(18, div_patch_dim, patch);
// Check the first row of the patch. // Check the first row of the patch.
for (size_t i = 0; i < 14 * 3; ++i) { for (size_t i = 0; i < 14 * 3; ++i) {
EXPECT_EQ(patch[i], image.data()[(14 * 224 + 2 * 14) * 3 + i]); EXPECT_EQ(patch[i], image.data()[(14 * 224 + 2 * 14) * 3 + i]);
@ -108,14 +110,15 @@ TEST(ImageTest, Non224) {
// Extract two patches. // Extract two patches.
const size_t kPatchValues = 14 * 14 * 3; // = 588 const size_t kPatchValues = 14 * 14 * 3; // = 588
float patch[kPatchValues]; float patch[kPatchValues];
const hwy::Divisor div_patch_dim(14);
// Patch 0 is just the "start" of the image. // Patch 0 is just the "start" of the image.
image.GetPatch(0, patch); image.GetPatch(0, div_patch_dim, patch);
EXPECT_NEAR(patch[0], Normalize(0.0f, max_value), 1e-6); EXPECT_NEAR(patch[0], Normalize(0.0f, max_value), 1e-6);
EXPECT_NEAR(patch[1], Normalize(1.0f, max_value), 1e-6); EXPECT_NEAR(patch[1], Normalize(1.0f, max_value), 1e-6);
EXPECT_NEAR(patch[2], Normalize(2.0f, max_value), 1e-6); EXPECT_NEAR(patch[2], Normalize(2.0f, max_value), 1e-6);
// The "image" has 4x3 patches, so patch 6 has coordinates (1, 2) and its // The "image" has 4x3 patches, so patch 6 has coordinates (1, 2) and its
// pixel coordinates are offset by (14, 28). // pixel coordinates are offset by (14, 28).
image.GetPatch(6, patch); image.GetPatch(6, div_patch_dim, patch);
for (size_t n = 0; n < kPatchValues; ++n) { for (size_t n = 0; n < kPatchValues; ++n) {
size_t k = n % 3; size_t k = n % 3;
size_t j = ((n - k) / 3) % 14; size_t j = ((n - k) / 3) % 14;

View File

@ -21,10 +21,9 @@
#include "evals/benchmark_helper.h" #include "evals/benchmark_helper.h"
#include "gemma/configs.h" #include "gemma/configs.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "io/io.h" #include "paligemma/paligemma_helper.h"
#include "util/allocator.h" #include "util/allocator.h"
#include "hwy/tests/hwy_gtest.h" #include "hwy/tests/hwy_gtest.h"
#include "paligemma/paligemma_helper.h"
// This test can be run manually with the downloaded PaliGemma weights. // This test can be run manually with the downloaded PaliGemma weights.
// It should pass for `paligemma-3b-mix-224` and `paligemma2-3b-pt-448`. // It should pass for `paligemma-3b-mix-224` and `paligemma2-3b-pt-448`.
@ -72,9 +71,12 @@ TEST_F(PaliGemmaTest, QueryObjects) {
int main(int argc, char** argv) { int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
gcpp::InternalInit();
gcpp::GemmaEnv env(argc, argv); gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
consumed.AbortIfUnconsumed();
gcpp::GemmaEnv env(args);
gcpp::s_env = &env; gcpp::s_env = &env;
return RUN_ALL_TESTS(); return RUN_ALL_TESTS();

View File

@ -173,6 +173,8 @@ PYBIND11_MODULE(configs, py_module) {
.def_readwrite("secondary_eos_id", &ModelConfig::secondary_eos_id) .def_readwrite("secondary_eos_id", &ModelConfig::secondary_eos_id)
.def_readwrite("scale_base_names", &ModelConfig::scale_base_names) .def_readwrite("scale_base_names", &ModelConfig::scale_base_names)
.def_readwrite("internal", &ModelConfig::internal) .def_readwrite("internal", &ModelConfig::internal)
.def_readwrite("use_global_timescale",
&ModelConfig::use_global_timescale)
.def("add_layer_config", &ModelConfig::AddLayerConfig, .def("add_layer_config", &ModelConfig::AddLayerConfig,
arg("layer_config")) arg("layer_config"))

View File

@ -45,10 +45,7 @@ static void RemoveTrailingZeros(std::vector<int> &vec) {
// Wrapper around GemmaEnv to expose to Python. // Wrapper around GemmaEnv to expose to Python.
class GemmaModel { class GemmaModel {
public: public:
GemmaModel(const gcpp::LoaderArgs& loader, GemmaModel(const gcpp::GemmaArgs& args) : env_(args), last_prob_(0.0f) {}
const gcpp::ThreadingArgs& threading,
const gcpp::InferenceArgs& inference)
: env_(loader, threading, inference), last_prob_(0.0f) {}
// Generates a single example, given a prompt and a callback to stream the // Generates a single example, given a prompt and a callback to stream the
// generated tokens. // generated tokens.
@ -254,13 +251,15 @@ PYBIND11_MODULE(gemma, mod) {
py::class_<GemmaModel>(mod, "GemmaModel") py::class_<GemmaModel>(mod, "GemmaModel")
.def(py::init([](const std::string& tokenizer, const std::string& weights, .def(py::init([](const std::string& tokenizer, const std::string& weights,
size_t max_threads) { size_t max_threads) {
const gcpp::LoaderArgs loader(tokenizer, weights);
gcpp::ThreadingArgs threading; gcpp::ThreadingArgs threading;
threading.max_lps = max_threads; threading.max_lps = max_threads;
gcpp::InferenceArgs inference; gcpp::InferenceArgs inference;
inference.max_generated_tokens = 512; inference.max_generated_tokens = 512;
auto gemma =
std::make_unique<GemmaModel>(loader, threading, inference); const gcpp::GemmaArgs args(gcpp::LoaderArgs(tokenizer, weights),
threading, inference);
auto gemma = std::make_unique<GemmaModel>(args);
if (!gemma->ModelIsLoaded()) { if (!gemma->ModelIsLoaded()) {
throw std::invalid_argument("Could not load model."); throw std::invalid_argument("Could not load model.");
} }

View File

@ -22,6 +22,7 @@
#include <algorithm> // std::transform #include <algorithm> // std::transform
#include <string> #include <string>
#include <vector>
#include "io/io.h" // Path #include "io/io.h" // Path
#include "util/basics.h" // Tristate #include "util/basics.h" // Tristate
@ -29,6 +30,56 @@
namespace gcpp { namespace gcpp {
// For checking which args were not matched/consumed. Passed to each `*Args`
// ctor that parses argc/argv to ensure that their args are tracked, without
// requiring global state.
class ConsumedArgs {
public:
ConsumedArgs(int argc, char** argv) : argv_(argv), consumed_(argc) {
// We assume argc >= 1, because argv[0] is the binary name. That allows us
// to signal "called AbortIfUnconsumed" with an empty vector.
HWY_ASSERT(!consumed_.empty());
}
~ConsumedArgs() {
if (HWY_UNLIKELY(!consumed_.empty())) {
HWY_ABORT("AbortIfUnconsumed was not called.");
}
}
void NotifyConsumed(size_t idx) {
HWY_ASSERT(idx < consumed_.size());
HWY_ASSERT(consumed_[idx] == 0);
consumed_[idx] = 1;
}
// Returns the index of the first unconsumed arg, or 0 if none. Also disarms
// the abort in the dtor that checks whether this or `AbortIfUnconsumed` was
// called.
size_t FirstUnconsumed() {
// Ignore argv[0], which is the binary name.
for (size_t i = 1; i < consumed_.size(); ++i) {
if (HWY_UNLIKELY(consumed_[i] == 0)) {
consumed_.clear();
return i;
}
}
consumed_.clear();
return 0;
}
void AbortIfUnconsumed() {
const size_t i = FirstUnconsumed();
if (HWY_UNLIKELY(i != 0)) {
HWY_ABORT("Unrecognized arg %zu: %s\n", i, argv_[i]);
}
}
private:
char** argv_;
std::vector<uint8_t> consumed_;
};
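A brief sketch of the two ways to finish checking; `FirstUnconsumed` suits callers that prefer to report rather than abort (`consumed` and `argv` assumed in scope):
// After all *Args ctors have parsed argc/argv:
const size_t bad = consumed.FirstUnconsumed();  // 0 if everything was consumed
if (bad != 0) {
  HWY_WARN("Ignoring unrecognized arg %zu: %s\n", bad, argv[bad]);
}
// Alternatively, abort on the first unrecognized arg:
// consumed.AbortIfUnconsumed();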
// Args is a class that provides a ForEach member function which visits each of // Args is a class that provides a ForEach member function which visits each of
// its member variables. ArgsBase provides functions called by Args to // its member variables. ArgsBase provides functions called by Args to
// initialize values to their defaults (passed as an argument to the visitor), // initialize values to their defaults (passed as an argument to the visitor),
@ -93,12 +144,14 @@ class ArgsBase {
// consider adding a hash-map to speed this up. // consider adding a hash-map to speed this up.
class ParseVisitor { class ParseVisitor {
public: public:
ParseVisitor(int argc, char* argv[]) : argc_(argc), argv_(argv) {} ParseVisitor(int argc, char* argv[], ConsumedArgs& consumed)
: argc_(argc), argv_(argv), consumed_(consumed) {}
template <typename T> template <typename T>
void operator()(T& t, const char* name, const T& /*init*/, void operator()(T& t, const char* name, const T& /*init*/,
const char* /*help*/, int /*print_verbosity*/ = 0) const { const char* /*help*/, int /*print_verbosity*/ = 0) const {
const std::string prefixed = std::string("--") + name; const std::string prefixed = std::string("--") + name;
const std::string prefixed_eq = prefixed + "=";
for (int i = 1; i < argc_; ++i) { for (int i = 1; i < argc_; ++i) {
if (std::string(argv_[i]) == prefixed) { if (std::string(argv_[i]) == prefixed) {
if (i + 1 >= argc_) { if (i + 1 >= argc_) {
@ -107,6 +160,16 @@ class ArgsBase {
if (!SetValue(argv_[i + 1], t)) { if (!SetValue(argv_[i + 1], t)) {
HWY_ABORT("Invalid value for %s, got %s\n", name, argv_[i + 1]); HWY_ABORT("Invalid value for %s, got %s\n", name, argv_[i + 1]);
} }
consumed_.NotifyConsumed(i);
consumed_.NotifyConsumed(i + 1);
return;
}
if (std::string(argv_[i]).find(prefixed_eq) == 0) {
const char* value = argv_[i] + prefixed_eq.length();
if (!SetValue(value, t)) {
HWY_ABORT("Invalid value for %s, got %s\n", name, value);
}
consumed_.NotifyConsumed(i);
return; return;
} }
} }
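With this change, both `--flag value` and `--flag=value` are parsed and marked consumed. A minimal sketch using the `bind` flag of ThreadingArgs (shown further below); illustrative only:
char a0[] = "gemma", a1[] = "--bind", a2[] = "1", a3[] = "--bind=1";
char* argv_space[] = {a0, a1, a2};  // consumes argv[1] and argv[2]
char* argv_eq[] = {a0, a3};         // consumes argv[1] only
gcpp::ConsumedArgs c1(3, argv_space), c2(2, argv_eq);
gcpp::ThreadingArgs t1(3, argv_space, c1), t2(2, argv_eq, c2);
c1.AbortIfUnconsumed();  // no-op: everything was consumed
c2.AbortIfUnconsumed();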
@ -173,8 +236,9 @@ class ArgsBase {
} }
} }
int argc_; const int argc_;
char** argv_; char** const argv_;
ConsumedArgs& consumed_;
}; // ParseVisitor }; // ParseVisitor
template <class Visitor> template <class Visitor>
@ -203,15 +267,15 @@ class ArgsBase {
ForEach(visitor); ForEach(visitor);
} }
void Parse(int argc, char* argv[]) { void Parse(int argc, char* argv[], ConsumedArgs& consumed) {
ParseVisitor visitor(argc, argv); ParseVisitor visitor(argc, argv, consumed);
ForEach(visitor); ForEach(visitor);
} }
// For convenience, enables single-line constructor. // For convenience, enables single-line constructor.
void InitAndParse(int argc, char* argv[]) { void InitAndParse(int argc, char* argv[], ConsumedArgs& consumed) {
Init(); Init();
Parse(argc, argv); Parse(argc, argv, consumed);
} }
}; };

View File

@ -33,6 +33,9 @@ namespace gcpp {
// For hwy::BitSet4096. Note that KVs are extremely large for such batches. // For hwy::BitSet4096. Note that KVs are extremely large for such batches.
HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096; HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096;
// Multiplier so a u64 occupies an entire cache line; avoids false sharing.
HWY_INLINE_VAR constexpr size_t kU64PerLine = HWY_ALIGNMENT / sizeof(uint64_t);
enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 }; enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 };
static inline const char* ToString(Tristate t) { static inline const char* ToString(Tristate t) {

View File

@ -181,7 +181,15 @@ class MatPtr : public IFields {
Extents2D Extents() const { return Extents2D(Rows(), cols_); } Extents2D Extents() const { return Extents2D(Rows(), cols_); }
bool IsEmpty() const { return Rows() == 0 || cols_ == 0; } bool IsEmpty() const { return Rows() == 0 || cols_ == 0; }
bool SameShape(const MatPtr& other) const { bool SameShape(const MatPtr& other) const {
return Rows() == other.Rows() && cols_ == other.cols_; return Rows() == other.Rows() && Cols() == other.Cols();
}
void DebugCheckSameShape(const MatPtr& other) const {
if constexpr (HWY_IS_DEBUG_BUILD) {
if (!SameShape(other)) {
HWY_ABORT("%s: shape mismatch %zu x %zu vs %zu x %zu\n", name_.c_str(),
Rows(), Cols(), other.Rows(), other.Cols());
}
}
} }
// Future calls to `Rows()` during this class' lifetime (not serialized) // Future calls to `Rows()` during this class' lifetime (not serialized)
// will return this value. Used to set the actual number of rows for // will return this value. Used to set the actual number of rows for
@ -284,6 +292,9 @@ class MatPtrT : public MatPtr {
public: public:
using T = MatT; using T = MatT;
// Default constructor for use with uninitialized views.
MatPtrT() = default;
// Called by `MatStorageT`. // Called by `MatStorageT`.
MatPtrT(const char* name, Extents2D extents) MatPtrT(const char* name, Extents2D extents)
: MatPtr(name, TypeEnum<MatT>(), extents) {} : MatPtr(name, TypeEnum<MatT>(), extents) {}
@ -296,7 +307,10 @@ class MatPtrT : public MatPtr {
if (GetType() == Type::kUnknown) { if (GetType() == Type::kUnknown) {
SetType(TypeEnum<MatT>()); SetType(TypeEnum<MatT>());
} else { } else {
HWY_ASSERT(other.GetType() == TypeEnum<MatT>()); if (HWY_UNLIKELY(other.GetType() != TypeEnum<MatT>())) {
HWY_ABORT("Type mismatch: MatT %s, constructing from %s",
TypeName<MatT>(), TypeName(other.GetType()));
}
} }
} }
MatPtrT& operator=(const MatPtr& other) { MatPtrT& operator=(const MatPtr& other) {
@ -440,6 +454,21 @@ decltype(auto) CallUpcastedActivation(const MatPtr* base, const Func& func,
} }
} }
// Like CallUpcasted, but only for kv_cache types: kBF16 and kF32.
template <class Func, typename... Args>
decltype(auto) CallUpcastedKV(const MatPtr* base, const Func& func,
Args&&... args) {
if (base->GetType() == Type::kF32) {
const MatPtrT<float> mat(*base);
return func(&mat, std::forward<Args>(args)...);
} else if (base->GetType() == Type::kBF16) {
const MatPtrT<BF16> mat(*base);
return func(&mat, std::forward<Args>(args)...);
} else {
HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
}
}
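A minimal usage sketch (`kv` is a hypothetical MatPtr whose runtime type is kF32 or kBF16, and ProcessRow is a hypothetical callee):
CallUpcastedKV(&kv, [&](const auto* kv_t) {
  // kv_t is either MatPtrT<float>* or MatPtrT<BF16>*.
  for (size_t r = 0; r < kv_t->Rows(); ++r) {
    ProcessRow(kv_t->Row(r), kv_t->Cols());
  }
});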
void CopyMat(const MatPtr& from, MatPtr& to); void CopyMat(const MatPtr& from, MatPtr& to);
void ZeroInit(MatPtr& mat); void ZeroInit(MatPtr& mat);

View File

@ -19,20 +19,51 @@
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
#include <algorithm> // std::sort
#include <cmath> #include <cmath>
#include <iostream>
#include "util/basics.h" // RngStream
#include "util/mat.h"
#include "hwy/base.h" #include "hwy/base.h"
// IWYU pragma: begin_exports // IWYU pragma: begin_exports
#include "hwy/nanobenchmark.h"
#include "hwy/stats.h" #include "hwy/stats.h"
#include "hwy/tests/test_util.h" // RandomState #include "hwy/tests/test_util.h" // RandomState
// IWYU pragma: end_exports // IWYU pragma: end_exports
namespace gcpp { namespace gcpp {
// Returns the mean of the second quartile, which excludes outliers; we might
// not have enough samples for a reliable mode.
HWY_INLINE double TrimmedMean(double* seconds, size_t num) {
std::sort(seconds, seconds + num);
double sum = 0;
int count = 0;
for (size_t i = num / 4; i < num / 2; ++i) {
sum += seconds[i];
count += 1;
}
HWY_DASSERT(num != 0);
return sum / count;
}
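For example (values assumed): with eight samples the function averages sorted indices [num/4, num/2) = [2, 4), so the large outlier is ignored. Note the averaging window is empty for num == 1.
double s[8] = {1.0, 1.0, 1.1, 1.2, 1.3, 1.5, 2.0, 9.0};  // 9.0 is an outlier
const double mean = TrimmedMean(s, 8);                   // (1.1 + 1.2) / 2 == 1.15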
// Returns a uniformly distributed value in [-1, 1).
HWY_INLINE float RandomFloat(RngStream& rng) {
const uint32_t exp = hwy::BitCastScalar<uint32_t>(1.0f);
const uint32_t mantissa_mask = hwy::MantissaMask<float>();
const uint32_t representation = exp | (rng() & mantissa_mask);
const float f12 = hwy::BitCastScalar<float>(representation);
HWY_DASSERT(1.0f <= f12 && f12 < 2.0f); // exponent is 2^0, only mantissa
const float f = (2.0f * (f12 - 1.0f)) - 1.0f;
HWY_DASSERT(-1.0f <= f && f < 1.0f);
return f;
}
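The bit trick above overwrites the exponent with that of 1.0f and fills the mantissa with random bits, yielding a uniform value in [1, 2) which the final affine map sends to [-1, 1). A Highway-free sketch of the same idea (illustrative only; assumes <cstdint> and <cstring>):
float RandomFloatSketch(uint32_t random_bits) {
  const uint32_t mantissa_mask = (1u << 23) - 1;                 // low 23 bits
  const uint32_t bits = 0x3F800000u | (random_bits & mantissa_mask);
  float f12;
  std::memcpy(&f12, &bits, sizeof(f12));                         // f12 in [1, 2)
  return 2.0f * (f12 - 1.0f) - 1.0f;                             // in [-1, 1)
}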
// Returns random Gaussian (mean=0, stddev=1/3 similar to expected weights) // Returns random Gaussian (mean=0, stddev=1/3 similar to expected weights)
// using the central limit theorem. Avoid std::normal_distribution for // using the central limit theorem. Avoid std::normal_distribution for
// consistent cross-platform output. // consistent cross-platform output.
// TODO: use RngStream instead of RandomState.
HWY_INLINE double RandomGaussian(hwy::RandomState& rng) { HWY_INLINE double RandomGaussian(hwy::RandomState& rng) {
uint64_t sum = 0; uint64_t sum = 0;
constexpr int kReps = 40; constexpr int kReps = 40;
@ -71,6 +102,25 @@ HWY_INLINE void VerifyGaussian(hwy::Stats& stats) {
HWY_ASSERT(IsNear(3.0, stats.Kurtosis(), 0.3)); HWY_ASSERT(IsNear(3.0, stats.Kurtosis(), 0.3));
} }
// Fills `mat` with small, deterministic values for tests.
template <typename T>
void FillMatPtrT(MatPtrT<T>& mat) {
for (size_t i = 0; i < mat.Rows(); ++i) {
for (size_t j = 0; j < mat.Cols(); ++j) {
mat.Row(i)[j] = hwy::Unpredictable1() * 0.01f * (i + j + 1);
}
}
}
// Prints `mat` row by row to stderr, for debugging tests.
template <typename T>
void PrintMatPtr(const MatPtrT<T>& mat) {
for (size_t i = 0; i < mat.Rows(); ++i) {
for (size_t j = 0; j < mat.Cols(); ++j) {
std::cerr << mat.Row(i)[j] << " ,";
}
std::cerr << std::endl;
}
}
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_TEST_UTIL_H_ #endif // THIRD_PARTY_GEMMA_CPP_UTIL_TEST_UTIL_H_

View File

@ -187,7 +187,9 @@ class NestedPools {
// functions below. // functions below.
class IndexRangePartition { class IndexRangePartition {
public: public:
IndexRangePartition() = default; // for MMPartitions explicit IndexRangePartition(size_t single_task)
: range_(0, single_task), task_size_(single_task), num_tasks_(1) {}
IndexRangePartition(const IndexRange& range, const size_t task_size) IndexRangePartition(const IndexRange& range, const size_t task_size)
: range_(range), task_size_(static_cast<uint32_t>(task_size)) { : range_(range), task_size_(static_cast<uint32_t>(task_size)) {
const uint32_t num = static_cast<uint32_t>(range.Num()); const uint32_t num = static_cast<uint32_t>(range.Num());
@ -262,43 +264,6 @@ static inline IndexRangePartition StaticPartition(const IndexRange& range,
return IndexRangePartition(range, size); return IndexRangePartition(range, size);
} }
// Parallel-for over a single range. This takes care of translating the task
// index to a range.
template <class Func>
void ParallelizeOneRange(const IndexRangePartition& get1, hwy::ThreadPool& pool,
hwy::pool::Caller caller, const Func& func) {
const size_t num_tasks = get1.NumTasks();
pool.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) {
const IndexRange range1 = get1.Range(task);
func(range1, thread);
});
}
// Parallel-for over the Cartesian product of the two sets of ranges. This
// combines their indices into a single 'task' so they can be executed by one
// `pool.Run`, which increases the amount of work available to workers and
// reduces fork-join overhead vs. nested parallel-for loops. Calls `func` with
// the two ranges and the thread index within `pool`.
template <class Func>
void ParallelizeTwoRanges(const IndexRangePartition& get1,
const IndexRangePartition& get2,
hwy::ThreadPool& pool, hwy::pool::Caller caller,
const Func& func) {
const hwy::Divisor div1(static_cast<uint32_t>(get1.NumTasks()));
const size_t num_tasks = get1.NumTasks() * get2.NumTasks();
pool.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) {
HWY_DASSERT(task < (uint64_t{1} << 32));
const size_t idx2 = div1.Divide(static_cast<uint32_t>(task));
const size_t idx1 = div1.Remainder(static_cast<uint32_t>(task));
HWY_DASSERT(idx1 < get1.NumTasks());
HWY_DASSERT(idx2 < get2.NumTasks());
const IndexRange range1 = get1.Range(idx1);
const IndexRange range2 = get2.Range(idx2);
func(range1, range2, thread);
});
}
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_ #endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_

View File

@ -43,12 +43,9 @@ static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) {
const size_t num_tasks[4] = {HWY_MAX(1, num_workers / 2), num_workers * 1, const size_t num_tasks[4] = {HWY_MAX(1, num_workers / 2), num_workers * 1,
num_workers * 5, num_workers * 20}; num_workers * 5, num_workers * 20};
// Count tasks executed to ensure workers aren't optimized out. One per // Count tasks executed to ensure workers aren't optimized out.
// cache line to avoid false sharing. std::vector<uint64_t> counters(num_workers * kU64PerLine);
const size_t kSizePerLine = HWY_ALIGNMENT / sizeof(size_t); uint64_t prev_total = 0; // avoids having to reset counters.
std::vector<size_t> counters(num_workers * kSizePerLine);
size_t prev_total = 0; // avoids having to reset counters.
hwy::RandomState rng; hwy::RandomState rng;
for (size_t rep = 0; rep < 500; ++rep) { for (size_t rep = 0; rep < 500; ++rep) {
@ -63,13 +60,13 @@ static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) {
pool.Run(begin, end, [&](uint64_t task, size_t thread) { pool.Run(begin, end, [&](uint64_t task, size_t thread) {
HWY_ASSERT(begin <= task && task < end); HWY_ASSERT(begin <= task && task < end);
HWY_ASSERT(thread < num_workers); HWY_ASSERT(thread < num_workers);
counters[thread * kSizePerLine]++; counters[thread * kU64PerLine]++;
}); });
// Reduce count and ensure it matches the expected number of tasks. // Reduce count and ensure it matches the expected number of tasks.
size_t total = 0; uint64_t total = 0;
for (size_t i = 0; i < num_workers; ++i) { for (size_t i = 0; i < num_workers; ++i) {
total += counters[i * kSizePerLine]; total += counters[i * kU64PerLine];
} }
const size_t expected = end - begin; const size_t expected = end - begin;
HWY_ASSERT(total == prev_total + expected); HWY_ASSERT(total == prev_total + expected);
@ -100,7 +97,8 @@ ThreadingContext::ThreadingContext(const ThreadingArgs& args)
BoundedSlice(args.skip_lps, args.max_lps)), BoundedSlice(args.skip_lps, args.max_lps)),
cache_info(topology), cache_info(topology),
allocator(topology, cache_info, args.bind != Tristate::kFalse), allocator(topology, cache_info, args.bind != Tristate::kFalse),
pools(topology, allocator, args.max_threads, args.pin) { pools(topology, allocator, args.max_threads, args.pin),
tensor_output(args.tensor_output) {
PROFILER_ZONE("Startup.ThreadingContext autotune"); PROFILER_ZONE("Startup.ThreadingContext autotune");
TunePools(hwy::PoolWaitMode::kSpin, *this); TunePools(hwy::PoolWaitMode::kSpin, *this);
// kBlock is the default, hence set/tune it last. // kBlock is the default, hence set/tune it last.

View File

@ -23,6 +23,7 @@
#include <stdint.h> #include <stdint.h>
// IWYU pragma: begin_exports // IWYU pragma: begin_exports
#include "io/io.h" // Path
#include "util/allocator.h" #include "util/allocator.h"
#include "util/args.h" #include "util/args.h"
#include "util/basics.h" // Tristate #include "util/basics.h" // Tristate
@ -37,7 +38,9 @@ namespace gcpp {
// Optional arguments for `ThreadingContext` from the command line. // Optional arguments for `ThreadingContext` from the command line.
class ThreadingArgs : public ArgsBase<ThreadingArgs> { class ThreadingArgs : public ArgsBase<ThreadingArgs> {
public: public:
ThreadingArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } ThreadingArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}
ThreadingArgs() { Init(); }; ThreadingArgs() { Init(); };
// For BoundedTopology: // For BoundedTopology:
@ -55,6 +58,8 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
Tristate pin; // pin threads? Tristate pin; // pin threads?
Tristate spin; // use spin waits? Tristate spin; // use spin waits?
Path tensor_output; // empty, or directory for tensor output
template <class Visitor> template <class Visitor>
void ForEach(const Visitor& visitor) { void ForEach(const Visitor& visitor) {
// These can be used to partition CPU packages/sockets and their // These can be used to partition CPU packages/sockets and their
@ -85,6 +90,9 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
visitor(bind, "bind", Tristate::kDefault, visitor(bind, "bind", Tristate::kDefault,
"Bind memory to sockets? -1 = auto, 0 = no, 1 = yes.", 2); "Bind memory to sockets? -1 = auto, 0 = no, 1 = yes.", 2);
visitor(tensor_output, "tensor_output", Path(),
"Empty, or directory for tensor output.", 2);
} }
}; };
@ -100,7 +108,7 @@ struct ThreadingContext {
// Returns a worker index compatible with those from `ParallelFor`, assuming // Returns a worker index compatible with those from `ParallelFor`, assuming
// the current thread is running on one thread per cluster, which happens // the current thread is running on one thread per cluster, which happens
// when `ParallelismStrategy` is `kAcrossClusters`. // when `Parallelism` is `kAcrossClusters`.
size_t Worker(size_t cluster_idx) const { size_t Worker(size_t cluster_idx) const {
return cluster_idx * pools.MaxWorkersPerCluster(); return cluster_idx * pools.MaxWorkersPerCluster();
} }
@ -124,13 +132,15 @@ struct ThreadingContext {
// Per-package/cluster/within cluster pools of threads, matching `topology`. // Per-package/cluster/within cluster pools of threads, matching `topology`.
NestedPools pools; NestedPools pools;
Path tensor_output; // used by `TensorStats::Notify`.
}; };
#define GCPP_ZONE(ctx, global_idx, zone_enum) \ #define GCPP_ZONE(ctx, global_idx, zone_enum) \
PROFILER_ZONE3(ctx.profiler, global_idx, ctx.profiler_zones.Get(zone_enum)) PROFILER_ZONE3(ctx.profiler, global_idx, ctx.profiler_zones.Get(zone_enum))
// Describes the strategy for distributing parallel work across cores. // Describes the strategy for distributing parallel work across cores.
enum class ParallelismStrategy : uint8_t { enum class Parallelism : uint8_t {
// Execute using a single-threaded loop on the calling thread. The `worker` // Execute using a single-threaded loop on the calling thread. The `worker`
// index passed to the user's `Func` is unique across clusters. // index passed to the user's `Func` is unique across clusters.
kNone, kNone,
@ -154,56 +164,110 @@ enum class ParallelismStrategy : uint8_t {
kHierarchical, kHierarchical,
}; };
// Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes // Helper functions used to implement `ParallelFor`, also reused in multiple
// over clusters of ONE package, then within each cluster. // places. User code should call `ParallelFor` instead, which accepts the more
// convenient `Callers` enum.
//
// These call `func(task, worker)` for each task in `[0, num_tasks)`.
// NOTE: the worker argument is actually the `cluster_idx`, so that `Func` can
// pass that to `ParallelForWithinCluster`.
template <class Func>
void ParallelForAcrossClusters(size_t num_tasks, ThreadingContext& ctx,
hwy::pool::Caller caller, const Func& func) {
ctx.pools.AllClusters().Run(
0, num_tasks, caller,
[&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); });
}
template <class Func>
void ParallelForWithinCluster(size_t num_tasks, ThreadingContext& ctx,
size_t cluster_idx, hwy::pool::Caller caller,
const Func& func) {
const size_t cluster_base = ctx.Worker(cluster_idx);
ctx.pools.Cluster(cluster_idx)
.Run(0, num_tasks, caller, [&](uint64_t task, size_t worker) {
func(task, cluster_base + worker);
});
}
// Calls `func(range, cluster_idx)`, for passing to `*WithinCluster`.
template <class Func>
void ParallelPartitionAcrossClusters(const IndexRange range,
size_t task_multiple, size_t inner_tasks,
ThreadingContext& ctx,
hwy::pool::Caller caller,
const Func& func) {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
const IndexRangePartition ranges = StaticPartition(
range, ctx.pools.NumClusters() * inner_tasks, task_multiple);
ParallelForAcrossClusters(ranges.NumTasks(), ctx, caller,
[&](uint64_t task, size_t cluster_idx) {
func(ranges.Range(task), cluster_idx);
});
}
// Calls `func(range, worker)`.
template <class Func>
void ParallelPartitionWithinCluster(const IndexRange range,
size_t task_multiple, size_t inner_tasks,
ThreadingContext& ctx, size_t cluster_idx,
hwy::pool::Caller caller,
const Func& func) {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
const size_t num_workers = ctx.pools.Cluster(cluster_idx).NumWorkers();
const IndexRangePartition ranges =
StaticPartition(range, num_workers * inner_tasks, task_multiple);
ParallelForWithinCluster(
ranges.NumTasks(), ctx, cluster_idx, caller,
[&](uint64_t task, size_t worker) { func(ranges.Range(task), worker); });
}
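A brief sketch of how the two partition helpers compose in user code (`rows`, `ctx` and `ProcessRange` are assumed; this mirrors HierarchicalParallelFor below):
const hwy::pool::Caller caller = ctx.pool_callers.Get(Callers::kTest);
ParallelPartitionAcrossClusters(
    IndexRange(0, rows), /*task_multiple=*/1, /*inner_tasks=*/1, ctx, caller,
    [&](const IndexRange& cluster_range, size_t cluster_idx) {
      ParallelPartitionWithinCluster(
          cluster_range, /*task_multiple=*/1, /*inner_tasks=*/2, ctx,
          cluster_idx, caller, [&](const IndexRange& range, size_t worker) {
            ProcessRange(range, worker);  // hypothetical callee
          });
    });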
// Parallelizes across clusters, then within each cluster.
template <class Func> template <class Func>
void HierarchicalParallelFor(size_t num_tasks, ThreadingContext& ctx, void HierarchicalParallelFor(size_t num_tasks, ThreadingContext& ctx,
Callers callers, const Func& func) { Callers callers, const Func& func) {
const hwy::pool::Caller caller = ctx.pool_callers.Get(callers); const hwy::pool::Caller caller = ctx.pool_callers.Get(callers);
// If few tasks, run on a single cluster. Also avoids a bit of overhead if
// there is only one cluster. // If at most one task per cluster worker, run on a single cluster to avoid
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(); // the expensive cross-cluster barrier.
const size_t num_clusters = all_clusters.NumWorkers(); {
hwy::ThreadPool& cluster = ctx.pools.Cluster(0); const size_t cluster_idx = 0;
if (num_clusters == 1 || num_tasks <= cluster.NumWorkers()) { const size_t cluster_workers = ctx.pools.Cluster(cluster_idx).NumWorkers();
return cluster.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) { if (HWY_UNLIKELY(num_tasks <= cluster_workers)) {
func(task, thread); return ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller,
}); func);
}
} }
// Assign each cluster a sub-range. ParallelPartitionAcrossClusters(
const IndexRangePartition ranges = IndexRange(0, num_tasks), /*task_multiple=*/1, /*inner_tasks=*/1, ctx,
StaticPartition(IndexRange(0, num_tasks), num_clusters, 1); caller, [&](const IndexRange& cluster_range, size_t cluster_idx) {
ParallelizeOneRange(ranges, all_clusters, caller, ParallelForWithinCluster(cluster_range.Num(), ctx, cluster_idx, caller,
[&](const IndexRange& range, const size_t cluster_idx) { [&](uint64_t i, size_t worker) {
hwy::ThreadPool& cluster = func(cluster_range.begin() + i, worker);
ctx.pools.Cluster(cluster_idx); });
const size_t cluster_base = });
cluster_idx * ctx.pools.MaxWorkersPerCluster();
cluster.Run(range.begin(), range.end(), caller,
[&](uint64_t task, size_t thread) {
func(task, cluster_base + thread);
});
});
} }
// Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the // Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the
// number/type of workers determined by `parallelism`. `cluster_idx` is for // number/type of workers determined by `parallelism`. NOTE: worker is actually
// `parallelism == kWithinCluster`, and should be 0 if unknown. // `cluster_idx` for `kAcrossClusters`. The `cluster_idx` argument is for
// `parallelism == {kWithinCluster, kNone}`, and should be 0 if unknown.
template <class Func> template <class Func>
void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks, void ParallelFor(Parallelism parallelism, size_t num_tasks,
ThreadingContext& ctx, size_t cluster_idx, Callers callers, ThreadingContext& ctx, size_t cluster_idx, Callers callers,
const Func& func) { const Func& func) {
HWY_DASSERT(cluster_idx < ctx.topology.NumClusters()); HWY_DASSERT(cluster_idx < ctx.topology.NumClusters());
if (cluster_idx != 0) { if (cluster_idx != 0) {
// If already running across clusters, only use within-cluster modes. // If already running across clusters, only use within-cluster modes.
HWY_DASSERT(parallelism == ParallelismStrategy::kNone || HWY_DASSERT(parallelism == Parallelism::kNone ||
parallelism == ParallelismStrategy::kWithinCluster); parallelism == Parallelism::kWithinCluster);
} }
const hwy::pool::Caller caller = ctx.pool_callers.Get(callers); const hwy::pool::Caller caller = ctx.pool_callers.Get(callers);
switch (parallelism) { switch (parallelism) {
case ParallelismStrategy::kNone: { case Parallelism::kNone: {
const size_t worker = ctx.Worker(cluster_idx); const size_t worker = ctx.Worker(cluster_idx);
for (size_t task = 0; task < num_tasks; ++task) { for (size_t task = 0; task < num_tasks; ++task) {
func(task, worker); func(task, worker);
@ -211,40 +275,28 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
return; return;
} }
case ParallelismStrategy::kAcrossClusters: case Parallelism::kAcrossClusters:
return ctx.pools.AllClusters().Run( return ParallelForAcrossClusters(
0, num_tasks, caller, num_tasks, ctx, caller,
[&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); }); [&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); });
case ParallelismStrategy::kWithinCluster: { case Parallelism::kWithinCluster:
// Ensure the worker argument is unique across clusters, because it is return ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller,
// used for TLS indexing for example in profiler.h. func);
const size_t base = ctx.Worker(cluster_idx);
return ctx.pools.Cluster(cluster_idx)
.Run(0, num_tasks, caller, [&](uint64_t task, size_t worker) {
func(task, base + worker);
});
}
case ParallelismStrategy::kFlat: { case Parallelism::kFlat:
// Check for single cluster; if not, we must compute `cluster_base` for // Choose a single pool: the only cluster, or across all clusters
// consistent and non-overlapping worker indices. // (slower synchronization, but more memory bandwidth)
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(); if (HWY_UNLIKELY(ctx.pools.NumClusters() == 1)) {
const size_t num_clusters = all_clusters.NumWorkers(); return ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller,
if (num_clusters == 1) { func);
return ctx.pools.Cluster(cluster_idx)
.Run(0, num_tasks, caller,
[&](uint64_t task, size_t worker) { func(task, worker); });
} }
return ParallelForAcrossClusters(num_tasks, ctx, caller,
[&](uint64_t task, size_t cluster_idx) {
func(task, ctx.Worker(cluster_idx));
});
return all_clusters.Run(0, num_tasks, caller, case Parallelism::kHierarchical:
[&](uint64_t task, size_t cluster_idx) {
const size_t worker = ctx.Worker(cluster_idx);
func(task, worker);
});
}
case ParallelismStrategy::kHierarchical:
return HierarchicalParallelFor(num_tasks, ctx, callers, func); return HierarchicalParallelFor(num_tasks, ctx, callers, func);
} }
} }
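For reference, a minimal call with the renamed enum (mirroring the batched ops earlier in this change; `out` and `ProcessRow` are assumed):
ParallelFor(Parallelism::kFlat, out.Rows(), ctx, /*cluster_idx=*/0,
            Callers::kOpsAddFromBatched, [&](uint64_t row, size_t worker) {
              ProcessRow(row, worker);  // worker indices are unique across clusters
            });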

View File

@ -202,59 +202,7 @@ TEST(ThreadingTest, TestStaticPartition) {
} }
} }
TEST(ThreadingTest, TestParallelizeOneRange) { static uint64_t outputs[hwy::kMaxLogicalProcessors * kU64PerLine];
const IndexRange range(0, 10);
const IndexRangePartition partition = StaticPartition(range, 2, 4);
hwy::ThreadPool null_pool(0);
size_t calls = 0;
ParallelizeOneRange(partition, null_pool, kCaller,
[&](const IndexRange& range, size_t) {
if (++calls == 1) {
HWY_ASSERT(range.begin() == 0 && range.end() == 8);
} else {
HWY_ASSERT(range.begin() == 8 && range.end() == 10);
}
});
HWY_ASSERT(calls == 2);
}
TEST(ThreadingTest, TestParallelizeTwoRanges) {
const IndexRangePartition partition1 =
StaticPartition(IndexRange(0, 10), 2, 4);
const IndexRangePartition partition2 =
MaxSizePartition(IndexRange(128, 256), 32, 32);
HWY_ASSERT(partition2.NumTasks() == 4);
hwy::ThreadPool null_pool(0);
{
size_t calls = 0;
ParallelizeTwoRanges(
partition1, partition2, null_pool, kCaller,
[&](const IndexRange& range1, const IndexRange& range2, size_t) {
++calls;
HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8);
HWY_ASSERT(range2.begin() % 32 == 0);
HWY_ASSERT(range2.Num() % 32 == 0);
});
HWY_ASSERT(calls == 2 * 4);
}
// Also swap order to test Remainder() logic.
{
size_t calls = 0;
ParallelizeTwoRanges(
partition2, partition1, null_pool, kCaller,
[&](const IndexRange& range2, const IndexRange& range1, size_t) {
++calls;
HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8);
HWY_ASSERT(range2.begin() % 32 == 0);
HWY_ASSERT(range2.Num() % 32 == 0);
});
HWY_ASSERT(calls == 2 * 4);
}
}
static constexpr size_t kU64PerThread = HWY_ALIGNMENT / sizeof(size_t);
static uint64_t outputs[hwy::kMaxLogicalProcessors * kU64PerThread];
std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) { std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
// Governs duration of test; avoid timeout in debug builds. // Governs duration of test; avoid timeout in debug builds.
@ -268,7 +216,7 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
const double t0 = hwy::platform::Now(); const double t0 = hwy::platform::Now();
for (size_t reps = 0; reps < 1200; ++reps) { for (size_t reps = 0; reps < 1200; ++reps) {
pool.Run(0, pool.NumWorkers(), kCaller, [&](uint64_t task, size_t thread) { pool.Run(0, pool.NumWorkers(), kCaller, [&](uint64_t task, size_t thread) {
outputs[thread * kU64PerThread] = base + thread; outputs[thread * kU64PerLine] = base + thread;
}); });
hwy::PreventElision(outputs[base]); hwy::PreventElision(outputs[base]);
if (pool.AutoTuneComplete()) break; if (pool.AutoTuneComplete()) break;
@ -309,7 +257,7 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
const uint64_t t0 = hwy::timer::Start(); const uint64_t t0 = hwy::timer::Start();
pool.Run(0, pool.NumWorkers(), kCaller, pool.Run(0, pool.NumWorkers(), kCaller,
[&](uint64_t task, size_t thread) { [&](uint64_t task, size_t thread) {
outputs[thread * kU64PerThread] = base + thread; outputs[thread * kU64PerLine] = base + thread;
}); });
const uint64_t t1 = hwy::timer::Stop(); const uint64_t t1 = hwy::timer::Stop();
times.push_back(t1 - t0); times.push_back(t1 - t0);
@ -319,7 +267,7 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
const uint64_t t0 = hwy::timer::Start(); const uint64_t t0 = hwy::timer::Start();
pool.Run(0, pool.NumWorkers(), kCaller, pool.Run(0, pool.NumWorkers(), kCaller,
[&](uint64_t task, size_t thread) { [&](uint64_t task, size_t thread) {
outputs[thread * kU64PerThread] = base + thread; outputs[thread * kU64PerLine] = base + thread;
}); });
const uint64_t t1 = hwy::timer::Start(); const uint64_t t1 = hwy::timer::Start();
times.push_back(t1 - t0); times.push_back(t1 - t0);
@ -366,10 +314,10 @@ TEST(ThreadingTest, BenchJoin) {
// Verify outputs to ensure the measured code is not a no-op. // Verify outputs to ensure the measured code is not a no-op.
for (size_t lp = 0; lp < pool.NumWorkers(); ++lp) { for (size_t lp = 0; lp < pool.NumWorkers(); ++lp) {
HWY_ASSERT(outputs[lp * kU64PerThread] >= 1); HWY_ASSERT(outputs[lp * kU64PerLine] >= 1);
HWY_ASSERT(outputs[lp * kU64PerThread] <= 1 + pool.NumWorkers()); HWY_ASSERT(outputs[lp * kU64PerLine] <= 1 + pool.NumWorkers());
for (size_t i = 1; i < kU64PerThread; ++i) { for (size_t i = 1; i < kU64PerLine; ++i) {
HWY_ASSERT(outputs[lp * kU64PerThread + i] == 0); HWY_ASSERT(outputs[lp * kU64PerLine + i] == 0);
} }
} }
}; };

View File

@ -21,12 +21,14 @@
#include <vector> #include <vector>
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/bit_set.h"
namespace gcpp { namespace gcpp {
// Returns set of LPs available for use. // Returns set of LPs available for use.
static LPS EnabledLPs(const BoundedSlice& lp_slice) { static LPS EnabledLPs(const BoundedSlice& lp_slice) {
LPS enabled_lps; LPS enabled_lps;
const size_t num_lps = hwy::TotalLogicalProcessors();
// Thread-safe caching during the first call because subsequent pinning // Thread-safe caching during the first call because subsequent pinning
// overwrites the main thread's affinity. // overwrites the main thread's affinity.
@@ -35,6 +37,7 @@ static LPS EnabledLPs(const BoundedSlice& lp_slice) {
     if (!GetThreadAffinity(affinity)) affinity = LPS();
     return affinity;
   }();
+
   if (HWY_LIKELY(affinity.Any())) {
     // To honor taskset/numactl *and* the user's `lp_slice`, we interpret
     // the latter as a slice of the 1-bits of `enabled_lps`. Note that this
@@ -48,18 +51,32 @@ static LPS EnabledLPs(const BoundedSlice& lp_slice) {
       }
       ++enabled_idx;
     });
-  } else {
-    const size_t num_lps = hwy::TotalLogicalProcessors();
-    // Do not warn on Apple, where affinity is not supported.
-    if (!HWY_OS_APPLE) {
-      HWY_WARN("unknown OS affinity, max %zu LPs and slice %zu.", num_lps,
-               lp_slice.Num(num_lps));
+  }
+
+  if (HWY_UNLIKELY(!enabled_lps.Any())) {
+    // First warn: either about unknown affinity, or no overlap with `lp_slice`.
+    if (!affinity.Any()) {
+      // Do not warn on Apple, where affinity is not supported.
+      if (!HWY_OS_APPLE) {
+        HWY_WARN("unknown OS affinity, max %zu LPs and slice %zu.", num_lps,
+                 lp_slice.Num(num_lps));
+      }
+    } else {
+      HWY_WARN("LP slice [%zu, %zu) of initial affinity %zu is empty.",
+               lp_slice.Begin(), lp_slice.End(num_lps), affinity.Count());
     }
+    // Set `enabled_lps` based only on `lp_slice` and total logical processors.
     for (size_t lp = 0; lp < num_lps; ++lp) {
       if (lp_slice.Contains(num_lps, lp)) {
         enabled_lps.Set(lp);
       }
     }
+    if (!enabled_lps.Any()) {
+      HWY_WARN("no enabled LPs of total %zu, slice [%zu, %zu).", num_lps,
+               lp_slice.Begin(), lp_slice.End(affinity.Count()));
+    }
   }

   // Without threading support, only keep the first enabled LP; it might still
@@ -72,6 +89,7 @@ static LPS EnabledLPs(const BoundedSlice& lp_slice) {
     HWY_WARN("Warning, threads not supported, using only the main thread.");
   }

+  HWY_ASSERT(enabled_lps.Any());
   return enabled_lps;
 }
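
For context, the "slice of the 1-bits" interpretation mentioned in the comments above can be illustrated with plain std types; LPS and BoundedSlice in the real code play the roles of the vector and the [begin, end) pair here (this is a sketch of the idea, not the library's implementation):

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  // Affinity mask reported by the OS (e.g. set via taskset): LPs 2, 3, 6, 7.
  const std::vector<size_t> affinity = {2, 3, 6, 7};
  // User slice [1, 3) selects the 2nd and 3rd *enabled* LPs, not LPs 1 and 2.
  const size_t begin = 1, end = 3;
  std::vector<size_t> enabled;
  size_t enabled_idx = 0;
  for (const size_t lp : affinity) {
    if (begin <= enabled_idx && enabled_idx < end) enabled.push_back(lp);
    ++enabled_idx;
  }
  for (const size_t lp : enabled) printf("%zu ", lp);  // prints: 3 6
  printf("\n");
  return 0;
}
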
@@ -156,12 +174,13 @@ constexpr size_t kMaxLPsPerCluster = 6;

 #if !GEMMA_DISABLE_TOPOLOGY

-static size_t CoresFromLPs(const LPS& lps, const hwy::Topology& topology) {
-  LPS cores;
-  lps.Foreach([&](size_t lp) {
-    if (topology.lps[lp].smt == 0) cores.Set(lp);
-  });
-  return cores.Count();
+// Returns number of distinct SMT (hyperthreads).
+static size_t NumSMT(const hwy::Topology& topology) {
+  hwy::BitSet64 smt;
+  for (const hwy::Topology::LP& lp : topology.lps) {
+    smt.Set(lp.smt);
+  }
+  return smt.Count();
 }

 // tcluster is a modifiable copy of the first cluster in the package.
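
The new NumSMT() collapses the per-LP smt indices into a 64-bit set and counts the distinct values. A minimal standalone use of hwy::BitSet64 showing the same idea; the smt_per_lp array is a made-up stand-in for hwy::Topology::lps:

#include <cstddef>
#include <cstdio>

#include "hwy/bit_set.h"

// Counts distinct SMT indices: 2 on a 2-way SMT machine, 1 with SMT disabled.
size_t CountDistinctSMT(const size_t* smt_per_lp, size_t num_lps) {
  hwy::BitSet64 smt;
  for (size_t i = 0; i < num_lps; ++i) {
    smt.Set(smt_per_lp[i]);  // smt indices are tiny, so they fit in 64 bits
  }
  return smt.Count();
}

int main() {
  const size_t smt_per_lp[] = {0, 1, 0, 1};  // hypothetical 2-core/4-thread CPU
  printf("%zu\n", CountDistinctSMT(smt_per_lp, 4));  // prints 2
  return 0;
}
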
@@ -187,34 +206,66 @@ void BoundedTopology::SplitLargeCluster(const LPS& enabled_lps,
   }
 }

-// Main part of ctor, called when topology is known.
-bool BoundedTopology::InitFromTopology(const LPS& enabled_lps) {
-  const size_t tpkg_idx = package_slice_.Begin();
-  HWY_ASSERT(tpkg_idx < topology_.packages.size());
-  const hwy::Topology::Package& tpackage = topology_.packages[tpkg_idx];
-  const std::vector<hwy::Topology::Cluster>& tclusters = tpackage.clusters;
+using TClusters = std::vector<hwy::Topology::Cluster>;
+
+// Returns false if no cluster in `tclusters` has any enabled LPs.
+static bool AnyEnabledLPs(const TClusters& tclusters, const LPS& enabled_lps) {
   if (HWY_UNLIKELY(tclusters.empty())) {
-    HWY_WARN("Topology: no clusters found in package %zu.", tpkg_idx);
+    HWY_WARN("Topology: no clusters found.");
     return false;
   }

-  size_t max_tcluster_cores = 0;
-  size_t max_tcluster_lps = 0;
   for (const hwy::Topology::Cluster& tcluster : tclusters) {
-    const size_t cores = CoresFromLPs(tcluster.lps, topology_);
-    const size_t lps = tcluster.lps.Count();
-    max_tcluster_cores = HWY_MAX(max_tcluster_cores, cores);
-    max_tcluster_lps = HWY_MAX(max_tcluster_lps, lps);
+    bool any_lp_enabled = false;
+    tcluster.lps.Foreach(
+        [&](size_t lp) { any_lp_enabled |= (enabled_lps.Get(lp)); });
+    if (any_lp_enabled) return true;
   }
-  HWY_ASSERT(max_tcluster_cores != 0);
-  HWY_ASSERT(max_tcluster_lps >= max_tcluster_cores);
+
+  // No warning: this can happen if OS affinity limits us to the second package.
+  return false;
+}
+
+// Returns nullptr on failure. Also attempts `1 - tpkg_idx`, which is suitable
+// for the common case of up to two packages.
+static const TClusters* GetPackageClusters(const hwy::Topology& topology,
+                                           size_t tpkg_idx,
+                                           const LPS& enabled_lps) {
+  const size_t num_packages = topology.packages.size();
+  HWY_ASSERT(tpkg_idx < num_packages);
+  {
+    const TClusters& tclusters = topology.packages[tpkg_idx].clusters;
+    if (AnyEnabledLPs(tclusters, enabled_lps)) return &tclusters;
+  }
+
+  // Retry with the other package, if any.
+  tpkg_idx ^= 1;
+  if (tpkg_idx == num_packages) return nullptr;
+  {
+    const TClusters& tclusters = topology.packages[tpkg_idx].clusters;
+    if (AnyEnabledLPs(tclusters, enabled_lps)) return &tclusters;
+  }
+
+  HWY_WARN(
+      "Ignoring topology (%zu tpackages) because no clusters overlap with the "
+      "OS affinity (%zu enabled LPs): ",
+      num_packages, enabled_lps.Count());
+  enabled_lps.Foreach([](size_t lp) { fprintf(stderr, "%zu, ", lp); });
+  return nullptr;
+}
+
+// Main part of ctor, called when topology is known.
+bool BoundedTopology::InitFromTopology(const LPS& enabled_lps) {
+  const TClusters* maybe_tclusters =
+      GetPackageClusters(topology_, package_slice_.Begin(), enabled_lps);
+  if (!maybe_tclusters) return false;
+  const TClusters& tclusters = *maybe_tclusters;

   // Populate `clusters` with the subset of clusters in `cluster_slice` that
   // have any enabled LPs.
   clusters_.reserve(cluster_slice_.Num(tclusters.size()));
   cluster_slice_.Foreach("cluster", tclusters.size(), [&](size_t cluster_idx) {
-    const hwy::Topology::Cluster& tcluster = tpackage.clusters[cluster_idx];
-    Cluster cluster(enabled_lps, topology_.lps, tcluster);
+    Cluster cluster(enabled_lps, topology_.lps, tclusters[cluster_idx]);
     // Skip if empty, i.e. too few `enabled_lps`.
     if (HWY_LIKELY(cluster.NumWorkers() != 0)) {
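
The `tpkg_idx ^= 1` retry in GetPackageClusters above relies on XOR-with-1 toggling between package 0 and package 1, plus the `tpkg_idx == num_packages` guard for single-package machines. A tiny standalone sketch of just that arithmetic; num_packages here is a hypothetical value:

#include <cstddef>
#include <cstdio>

int main() {
  const size_t num_packages = 1;  // hypothetical single-socket machine
  size_t tpkg_idx = 0;
  tpkg_idx ^= 1;  // 0 -> 1 (and 1 -> 0 on a two-package machine)
  if (tpkg_idx == num_packages) {
    printf("no second package to retry\n");  // taken when num_packages == 1
  }
  return 0;
}
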
@@ -223,14 +274,10 @@ bool BoundedTopology::InitFromTopology(const LPS& enabled_lps) {
       nodes_.Set(cluster.Node());
     }
   });
-  if (HWY_UNLIKELY(clusters_.empty())) {
-    HWY_WARN("Too restrictive cluster_slice or enabled_lps, no clusters left.");
-    return false;
-  }

   if (kSplitLargeClusters && clusters_.size() == 1 &&
       enabled_lps.Count() >= 16) {
-    SplitLargeCluster(enabled_lps, tpackage.clusters[0]);
+    SplitLargeCluster(enabled_lps, tclusters[0]);
   }

   // Sort by descending 'size' so that users who only use one get the largest.
@@ -239,20 +286,23 @@ bool BoundedTopology::InitFromTopology(const LPS& enabled_lps) {
     return a.NumWorkers() > b.NumWorkers();
   });

-  // Largest number of enabled workers in any cluster, for `topology_string_`.
-  // This may be less than `max_tcluster_cores` if `enabled_lps` excludes some.
-  size_t max_cluster_workers = 0;
-  for (const Cluster& c : clusters_) {
-    max_cluster_workers = HWY_MAX(max_cluster_workers, c.NumWorkers());
+  // Happens if all LPs are HTs (we checked that at least some LPs are enabled).
+  if (HWY_UNLIKELY(clusters_.empty())) {
+    HWY_WARN(
+        "Ignoring topology - no usable clusters. cluster_slice [%zu, %zu), "
+        "%zu tclusters, %zu tLPs, %zu enabled LPs: ",
+        cluster_slice_.Begin(), cluster_slice_.End(tclusters.size()),
+        tclusters.size(), topology_.lps.size(), enabled_lps.Count());
+    enabled_lps.Foreach([](size_t lp) { fprintf(stderr, "%zu, ", lp); });
+    return false;
   }
-  HWY_ASSERT(max_cluster_workers <= max_tcluster_cores);

+  // Do not warn about large clusters: GNR has 40.
+  const size_t num_smt = NumSMT(topology_);
   snprintf(topology_string_, sizeof(topology_string_),
            "%zuS %zuX %zuC %zuH, using %zuX %zuC (nodes=%zu)",
-           topology_.packages.size(), tclusters.size(), max_tcluster_cores,
-           max_tcluster_lps / max_tcluster_cores, NumClusters(),
-           max_cluster_workers, nodes_.Count());
+           topology_.packages.size(), tclusters.size(),
+           tclusters[0].lps.Count() / num_smt, num_smt, NumClusters(),
+           clusters_[0].NumWorkers(), nodes_.Count());
   return true;
 }


@@ -93,7 +93,7 @@ class BoundedTopology {
   class Cluster {
    public:
-    Cluster(const LPS& lps);
+    explicit Cluster(const LPS& lps);
     Cluster(const LPS& enabled_lps,
             const std::vector<hwy::Topology::LP>& all_lps,
             const hwy::Topology::Cluster& tcluster);


@@ -51,6 +51,8 @@ const char* ZoneName(Zones zone) {
       return "Gen.SampleTop1";
     case Zones::kGenSampleTopK:
       return "Gen.SampleTopK";
+    case Zones::kGenStats:
+      return "Gen.Stats";
     case Zones::kMMDecompressA:
       return "MM.DecompressA";
     case Zones::kMMDispatch:
@@ -163,6 +165,8 @@ const char* CallerName(Callers caller) {
      return "ReadBatches";
    case Callers::kSampleAndStream:
      return "SampleAndStream";
+    case Callers::kTensorStats:
+      return "TensorStats";
    case Callers::kTest:  // only for unit tests.
      return "Test-only!";
    case Callers::kTunePool:


@@ -31,6 +31,7 @@ enum class Zones {  // Keep sorted
   kGenFFW,
   kGenSampleTop1,
   kGenSampleTopK,
+  kGenStats,
   kMMDecompressA,
   kMMDispatch,
   kMMMatMul,
@@ -96,6 +97,7 @@ enum class Callers {  // Keep sorted
   kReadAllToBF16,
   kReadBatches,
   kSampleAndStream,
+  kTensorStats,
   kTest,  // only for unit tests.
   kTunePool,
   kVitDotSoftmax1,