mirror of https://github.com/google/gemma.cpp.git
Merge branch 'dev' into upgrade-github-actions-node24
commit a4c78d4454

BUILD.bazel (76 changed lines)
@@ -66,6 +66,7 @@ cc_library(
    srcs = ["util/topology.cc"],
    hdrs = ["util/topology.h"],
    deps = [
        "@highway//:bit_set",
        "@highway//:hwy",
        "@highway//:topology",
    ],
@@ -111,6 +112,7 @@ cc_library(
        ":threading",
        ":topology",
        ":zones",
        "//io",
        "@highway//:hwy",
        "@highway//:hwy_test_util",
        "@highway//:profiler",
@@ -139,7 +141,7 @@ cc_test(
        ":kv_cache",
        ":mat",
        ":matmul",
        ":ops",
        ":test_util",
        ":threading_context",
        ":weights",
        "@googletest//:gtest_main",  # buildcleaner: keep
@@ -172,8 +174,11 @@ cc_library(
    name = "test_util",
    hdrs = ["util/test_util.h"],
    deps = [
        ":basics",
        ":mat",
        "@highway//:hwy",
        "@highway//:hwy_test_util",
        "@highway//:nanobenchmark",
        "@highway//:stats",
    ],
)
@@ -440,9 +445,9 @@ cc_test(
        ":gemma_lib",
        ":mat",
        ":ops",
        ":query",
        ":test_util",
        ":threading_context",
        ":zones",
        "@googletest//:gtest_main",  # buildcleaner: keep
        "//compression:test_util",
        "//compression:types",
@@ -519,32 +524,70 @@ cc_library(
    ],
)

cc_test(
    name = "kv_cache_test",
    srcs = ["gemma/kv_cache_test.cc"],
    deps = [
        ":configs",
        ":gemma_args",
        ":kv_cache",
        ":threading_context",
        "//testing/base/public:gunit_main",
        "@highway//:hwy",
    ],
)

cc_library(
    name = "query",
    hdrs = ["gemma/query.h"],
    deps = [
        ":basics",
        ":gemma_args",
        ":kv_cache",
        "@highway//:hwy",
    ],
)

cc_library(
    name = "gemma_args",
    hdrs = ["gemma/gemma_args.h"],
    deps = [
        ":args",
        ":basics",
        ":configs",
        ":mat",
        ":threading_context",
        "//io",
        "@highway//:hwy",
        "@highway//:profiler",
    ],
)

cc_test(
    name = "gemma_args_test",
    srcs = ["gemma/gemma_args_test.cc"],
    deps = [
        ":gemma_args",
        "@googletest//:gtest_main",  # buildcleaner: keep
    ],
)

cc_library(
    name = "gemma_lib",
    srcs = [
        "gemma/attention.cc",
        "gemma/flash_attention.cc",
        "gemma/gemma.cc",
        "gemma/tensor_stats.cc",
        "gemma/vit.cc",
    ],
    hdrs = [
        "gemma/activations.h",
        "gemma/attention.h",
        "gemma/flash_attention.h",
        "gemma/flash_structs.h",
        "gemma/gemma.h",
        "gemma/tensor_stats.h",
        "gemma/vit.h",
    ],
    exec_properties = {
@@ -555,6 +598,7 @@ cc_library(
        "gemma/gemma-inl.h",
    ],
    deps = [
        ":allocator",
        ":basics",
        ":configs",
        ":gemma_args",
@@ -564,6 +608,7 @@ cc_library(
        ":matmul_env",
        ":model_store",
        ":ops",
        ":query",
        ":threading",
        ":threading_context",
        ":weights",
@@ -577,8 +622,34 @@ cc_library(
        "@highway//:hwy",
        "@highway//:nanobenchmark",  # timer
        "@highway//:profiler",
        "@highway//:stats",
        "@highway//:thread_pool",
        "@highway//hwy/contrib/sort:vqsort",
    ] +
    [
    ],
)

cc_test(
    name = "gemma_lib_test",
    srcs = ["gemma/attention_test.cc"],
    # MatMulEnvs are up to 20GB large.
    tags = ["requires-mem:28g"],
    deps = [
        ":configs",
        ":gemma_args",
        ":gemma_lib",
        ":kv_cache",
        ":mat",
        ":matmul_env",
        ":test_util",
        ":threading_context",
        ":weights",
        "@googletest//:gtest_main",  # buildcleaner: keep
        "//compression:compress",
        "//compression:types",
        "@highway//:hwy",
        "@highway//:hwy_test_util",
    ],
)

@@ -604,7 +675,6 @@ cc_library(
        ":gemma_args",
        ":gemma_lib",
        ":matmul_env",
        ":ops",
        ":threading_context",
        ":tokenizer",
        "@google_benchmark//:benchmark",

CMakeLists.txt
@@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579 EXCLUDE_FROM_ALL)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 3b680cde3a556bead9cc23c8f595d07a44d5a0d5 EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(highway)

## Note: absl needs to be installed by sentencepiece. This will only happen if
@@ -82,6 +82,7 @@ set(SOURCES
  gemma/configs.h
  gemma/flash_attention.cc
  gemma/flash_attention.h
  gemma/flash_structs.h
  gemma/gemma_args.h
  gemma/gemma-inl.h
  gemma/gemma.cc
@@ -221,6 +222,7 @@ set(GEMMA_TEST_FILES
  compression/nuq_test.cc
  compression/sfp_test.cc
  evals/gemma_test.cc
  gemma/gemma_args_test.cc
  gemma/flash_attention_test.cc
  gemma/tensor_info_test.cc
  io/blob_store_test.cc

MODULE.bazel
@@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
# Require a more recent version.
git_override(
    module_name = "highway",
-    commit = "2a16a50ff61071bb25ddef0ce35d92b0e2b9c579",
+    commit = "3b680cde3a556bead9cc23c8f595d07a44d5a0d5",
    remote = "https://github.com/google/highway",
)

README.md (13 changed lines)
@@ -55,7 +55,6 @@ Guidelines](https://opensource.google.com/conduct/).

- CPU-only inference for: Gemma 2-3, PaliGemma 2.
- Sampling with TopK and temperature.
- Backward pass (VJP) and Adam optimizer for Gemma research.

- Optimizations

@@ -452,7 +451,7 @@ FetchContent_MakeAvailable(sentencepiece)
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
FetchContent_MakeAvailable(gemma)

-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 3b680cde3a556bead9cc23c8f595d07a44d5a0d5)
FetchContent_MakeAvailable(highway)
```

@@ -520,13 +519,19 @@ Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode
Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas
Fischbacher and Zoltan Szabadka. It was removed in 2025-09.

-Gemma-2 support was implemented in June/July 2024 with the help of several
-people.
+Gemma 2 support was implemented in June/July 2024 with the help of several
+people including Daniel Keysers and Phil Culliton.

PaliGemma support was implemented in September 2024 with contributions from
Daniel Keysers.

Gemma 3 support was implemented in January-March 2025 with contributions from
Daniel Keysers and Phil Culliton.

[Jan Wassenberg](mailto:janwas@google.com) has continued to contribute many
improvements, including major gains in efficiency, since the initial release.

[Phil Culliton](mailto:philculliton@google.com) has worked on model releases,
eval and validation, GTM, and quantization, since the initial release.

This is not an officially supported Google product.

compression/BUILD.bazel
@@ -1,4 +1,4 @@
-# Weight compression and analysis.
+# Compressed tensor types.

load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")
@@ -101,13 +101,11 @@ cc_test(
    # for test_suite.
    tags = ["hwy_ops_test"],
    deps = [
        ":distortion",
        ":int",
        "@googletest//:gtest_main",  # buildcleaner: keep
        "//:test_util",
        "@highway//:hwy",
        "@highway//:hwy_test_util",
        "@highway//:nanobenchmark",
    ],
)

@@ -135,8 +133,8 @@ cc_test(
    # for test_suite.
    tags = ["hwy_ops_test"],
    deps = [
        ":compress",
        ":distortion",
        ":sfp",
        "@googletest//:gtest_main",  # buildcleaner: keep
        "//:test_util",
        "@highway//:hwy",
@@ -182,7 +180,6 @@ cc_library(
        "//:mat",
        "//:threading_context",
        "@highway//:hwy",
        "@highway//:nanobenchmark",
        "@highway//:profiler",
        "@highway//:stats",
        "@highway//:thread_pool",
@@ -209,19 +206,3 @@ cc_test(
        "@highway//:hwy_test_util",
    ],
)
-
-# For internal experimentation
-cc_library(
-    name = "analyze",
-    textual_hdrs = ["analyze.h"],
-    deps = [
-        ":int",
-        ":nuq",
-        ":sfp",
-        ":types",
-        "@highway//:hwy",
-        "@highway//:stats",
-        "@highway//:thread_pool",
-        "@highway//hwy/contrib/sort:vqsort",
-    ],
-)

compression/analyze.h (deleted)
@@ -1,238 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Normal include guard to placate lint.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>  // memcpy

#include <cmath>    // std::signbit
#include <cstdlib>  // std::abs
#include <vector>

#include "compression/types.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/stats.h"

#endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_

// Actual per-target include guard.
#if defined(THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
#endif

#include "compression/nuq-inl.h"
#include "compression/sfp-inl.h"
#include "hwy/contrib/sort/vqsort-inl.h"
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {

class PerThread {
 public:
  void NotifyGroup(const float* group) {
    constexpr size_t kGroupSize = NuqStream::kGroupSize;
    hwy::Stats s_group;
    for (size_t i = 0; i < kGroupSize; ++i) {
      // Skip zero so we can see the lowest actual magnitude
      if (group[i] == 0.0f || group[i] == -0.0f) continue;
      s_all_.Notify(group[i]);
      s_group.Notify(group[i]);

      num_tiny_ += std::abs(group[i]) < 1e-3f;

      // b_magn100_.Notify(group[i] * 40.0f + 20.0f);
      const uint32_t binary32 =
          hwy::BitCastScalar<uint32_t>(std::abs(group[i]));

      // const int32_t exp = (binary32 >> 23) - 127;
      b_exp256_.Notify(binary32 >> 23);
      const uint32_t m4 = (binary32 & 0x7FFFFF) >> (23 - 4);
      b_m4_.Notify(m4);
    }
    s_group_ranges_.Notify(s_group.Max() - s_group.Min());
    s_group_mins_.Notify(s_group.Min());
    s_group_maxs_.Notify(s_group.Max());

    float desc[kGroupSize];
    memcpy(desc, group, kGroupSize * sizeof(group[0]));
    hn::VQSortStatic(desc, kGroupSize, hwy::SortDescending());

    // Find largest |max/min| (dynamic range)
    float max_ratio = 0.0f;
    for (size_t i = 0; i < kGroupSize; ++i) {
      if (desc[i] != 0.0f && desc[i] != -0.0f) {
        max_ratio = std::max(max_ratio, std::abs(desc[0] / desc[i]));
      }
    }
    s_group_max_vs_min_.Notify(max_ratio);

    // Relative errors
    float diffs[kGroupSize];
    for (size_t i = 0; i < kGroupSize - 1; ++i) {
      // was in descending order. Avoid div by 0. Ignore sign changes.
      diffs[i] = std::abs(desc[i]) < 1e-5
                     ? 0
                     : std::abs((desc[i] - desc[i + 1]) / desc[i]);
    }
    hn::VQSortStatic(diffs, kGroupSize, hwy::SortDescending());
    s_cut15_.Notify(diffs[15]);
  }

  void Assimilate(const PerThread& other) {
    num_tiny_ += other.num_tiny_;
    s_all_.Assimilate(other.s_all_);
    s_group_ranges_.Assimilate(other.s_group_ranges_);
    s_group_mins_.Assimilate(other.s_group_mins_);
    s_group_maxs_.Assimilate(other.s_group_maxs_);
    s_group_max_vs_min_.Assimilate(other.s_group_max_vs_min_);
    s_erange_.Assimilate(other.s_erange_);
    s_km_1_.Assimilate(other.s_km_1_);
    s_km_2_.Assimilate(other.s_km_2_);
    s_cut15_.Assimilate(other.s_cut15_);
    b_magn100_.Assimilate(other.b_magn100_);
    b_exp256_.Assimilate(other.b_exp256_);
    b_m4_.Assimilate(other.b_m4_);
  }

  void PrintAll() {
    const int skip = hwy::Stats::kNoGeomean;
    fprintf(stderr, "num tiny %zu\n", num_tiny_);
    fprintf(stderr, "weights %s\n", s_all_.ToString(skip).c_str());
    fprintf(stderr, " ranges %s\n", s_group_ranges_.ToString(skip).c_str());
    fprintf(stderr, " mins %s\n", s_group_mins_.ToString(skip).c_str());
    fprintf(stderr, " maxs %s\n", s_group_maxs_.ToString(skip).c_str());
    fprintf(stderr, " Mvm %s\n", s_group_max_vs_min_.ToString(skip).c_str());
    fprintf(stderr, " cut15 %s\n", s_cut15_.ToString(skip).c_str());
    fprintf(stderr, " erange %s\n", s_erange_.ToString(skip).c_str());
    fprintf(stderr, " km1 %s\n", s_km_1_.ToString(skip).c_str());
    fprintf(stderr, " km2 %s\n", s_km_2_.ToString(skip).c_str());

    // b_magn100_.Print("magn100");
    // b_exp256_.Print("exp");
    // b_m4_.Print("mantissa bits4");

    fprintf(stderr, "\n");
  }

 private:
  size_t num_tiny_ = 0;
  hwy::Stats s_all_;
  hwy::Stats s_group_ranges_;
  hwy::Stats s_group_mins_;
  hwy::Stats s_group_maxs_;
  hwy::Stats s_group_max_vs_min_;
  hwy::Stats s_erange_;
  hwy::Stats s_km_1_;
  hwy::Stats s_km_2_;
  hwy::Stats s_cut15_;
  hwy::Bins<100> b_magn100_;
  hwy::Bins<256> b_exp256_;
  hwy::Bins<16> b_m4_;
  uint8_t padding_[64];  // prevent false sharing
};

class PerLayer {
 public:
  void NotifyGroup(const float* group) {
    for (size_t i = 0; i < NuqStream::kGroupSize; ++i) {
      s_layer_.Notify(group[i]);
    }
  }

  void UpdateOutliers(const float* layer, size_t weights_per_layer) {
    const float layer_mean = s_layer_.Mean();
    const float layer_sd = s_layer_.StandardDeviation();
    for (size_t i = 0; i < weights_per_layer; ++i) {
      num_outliers_ +=
          std::abs(std::abs(layer[i]) - layer_mean) >= 3.0f * layer_sd;
    }
  }

  const hwy::Stats& GetStats() const { return s_layer_; }
  size_t Outliers() const { return num_outliers_; }

 private:
  hwy::Stats s_layer_;
  size_t num_outliers_ = 0;
  uint8_t padding[64];  // prevent false sharing
};

static HWY_NOINLINE void Analyze(const char* caption, float* mat, size_t layers,
                                 size_t weights_per_layer,
                                 hwy::ThreadPool& pool) {
  std::vector<PerThread> tls;
  std::vector<PerLayer> per_layer(layers);
  const auto init = [&](size_t num_threads) {
    tls.resize(num_threads);
    return true;
  };

  pool.Run(0, static_cast<uint32_t>(layers), init,
           [&](uint32_t idx_layer, size_t idx_thread) {
             PerThread& self = tls[idx_thread];
             const float* layer = &mat[idx_layer * weights_per_layer];
             // For each whole group in the layer
             for (size_t group_start = 0;
                  group_start + NuqStream::kGroupSize <= weights_per_layer;
                  group_start += NuqStream::kGroupSize) {
               const float* group = layer + group_start;
               per_layer[idx_layer].NotifyGroup(group);
               self.NotifyGroup(group);
             }

             per_layer[idx_layer].UpdateOutliers(layer, weights_per_layer);
           });

  const int skip = hwy::Stats::kNoGeomean;
  fprintf(stderr, "\n------------%s\n", caption);

  for (size_t i = 1; i < pool.NumWorkers(); ++i) {
    tls[0].Assimilate(tls[i]);
  }
  tls[0].PrintAll();

  hwy::Stats s_layer_ranges;
  hwy::Stats s_layer_outliers;
  for (size_t i = 0; i < layers; ++i) {
    fprintf(stderr, " %02zu %s\n", i,
            per_layer[i].GetStats().ToString(skip).c_str());
    const float range =
        per_layer[i].GetStats().Max() - per_layer[i].GetStats().Min();
    s_layer_ranges.Notify(range);
    s_layer_outliers.Notify((100.0 * per_layer[i].Outliers()) /
                            weights_per_layer);
  }
  fprintf(stderr, "layer outliers%% %s\n",
          s_layer_outliers.ToString(skip).c_str());
  fprintf(stderr, "layer ranges %s\n", s_layer_ranges.ToString(skip).c_str());
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();

#endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_

compression/compress-inl.h
@@ -82,6 +82,8 @@ struct CompressTraits<float> {
    hn::StoreU(raw1, df, packed.ptr + packed_ofs + NF);
  }

+  static float ToFloatSlow(const Packed x) { return x; }
+
  template <class DBF16, HWY_IF_BF16_D(DBF16), class VBF16 = hn::Vec<DBF16>>
  static HWY_INLINE void Load2(DBF16 dbf16,
                               const PackedSpan<const Packed>& packed,
@@ -254,6 +256,10 @@ struct CompressTraits<BF16> {
               packed.ptr + packed_ofs);
  }

+  static float ToFloatSlow(const Packed x) {
+    return hwy::ConvertScalarTo<float>(x);
+  }
+
  template <class DBF16, HWY_IF_BF16_D(DBF16)>
  static HWY_INLINE void Load2(DBF16 dbf16,
                               const PackedSpan<const Packed>& packed,
@@ -397,6 +403,27 @@ struct CompressTraits<SfpStream> {
    }
  }

+  // NOTE: this does not take into account the per-tensor scale.
+  static float ToFloatSlow(const Packed x) {
+    uint32_t sfp = x.byte;
+    HWY_ASSERT(sfp != 0x80);  // -0 is reserved
+
+    const uint32_t sign32 = (sfp & 0x80) << 24;
+    sfp &= 0x7F;
+    const bool large_e = sfp >= 64;
+    const size_t m_bits = large_e ? 3 : 2;
+    uint32_t m = sfp & ((1u << m_bits) - 1u);
+    size_t e = sfp >> m_bits;
+    if (sfp == 0) return 0.0f;
+    const uint32_t e_bias = large_e ? 15 : 23;
+    const uint32_t exp32 = static_cast<uint32_t>(127 + e - e_bias) << 23;
+    const uint32_t mnt32 = m << (23 - m_bits);
+    const uint32_t binary32 = sign32 | exp32 | mnt32;
+    float result;
+    hwy::CopySameSize(&binary32, &result);
+    return result;
+  }
+
  template <class D>  // Caller checks this is f32 or bf16
  static HWY_INLINE void Load2(D d, const PackedSpan<const Packed>& packed,
                               const size_t packed_ofs, hn::Vec<D>& raw0,
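The SFP8 decode above is self-contained bit math, so it can be exercised outside of Highway. Below is a minimal standalone sketch (hypothetical, not part of this commit): `DecodeSfp8` mirrors the `ToFloatSlow` logic, using plain `memcpy` instead of `hwy::CopySameSize`, and like the original it omits the per-tensor scale.

```cpp
// Hypothetical standalone sketch of the SFP8 decode above.
#include <cstdint>
#include <cstdio>
#include <cstring>

static float DecodeSfp8(uint32_t sfp) {
  const uint32_t sign32 = (sfp & 0x80) << 24;  // sign moves to bit 31
  sfp &= 0x7F;
  if (sfp == 0) return 0.0f;
  // Upper half of the code space uses 3 mantissa bits and bias 15,
  // lower half uses 2 mantissa bits and bias 23 (per the code above).
  const bool large_e = sfp >= 64;
  const uint32_t m_bits = large_e ? 3 : 2;
  const uint32_t m = sfp & ((1u << m_bits) - 1u);
  const uint32_t e = sfp >> m_bits;
  const uint32_t e_bias = large_e ? 15 : 23;
  const uint32_t binary32 =
      sign32 | ((127 + e - e_bias) << 23) | (m << (23 - m_bits));
  float f;
  memcpy(&f, &binary32, sizeof(f));  // bit-cast, like hwy::CopySameSize
  return f;
}

int main() {
  // 0x80 (negative zero) is reserved; skip it as the assertion above does.
  for (unsigned sfp = 0; sfp < 256; ++sfp) {
    if (sfp == 0x80) continue;
    printf("%3u -> %e\n", sfp, DecodeSfp8(sfp));
  }
  return 0;
}
```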
@@ -437,6 +464,12 @@ struct CompressTraits<I8Stream> {
    IntCodec::Dec2(d, packed, packed_ofs, raw0, raw1);
  }

+  static float ToFloatSlow(const Packed x) {
+    HWY_DASSERT(!"Not supported - requires a stream");
+    return 0.0f;
+  }
+
  // Store2 is not yet implemented.

  template <class D, typename Raw>
  static HWY_INLINE void DecompressAndZeroPad(
      D d, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
@@ -483,6 +516,10 @@ struct CompressTraits<NuqStream> {
    NuqCodec::Dec2(d, packed, packed_ofs, raw0, raw1);
  }

+  static float ToFloatSlow(const Packed x) {
+    HWY_DASSERT(!"Not supported - requires a stream");
+    return 0.0f;
+  }
+
  // Store2 is not yet implemented.

  template <class D, typename Raw>
@@ -604,6 +641,13 @@ HWY_INLINE void DecompressAndZeroPad(DRaw d, const PackedSpan<Packed>& packed,
  Traits::DecompressAndZeroPad(d, MakeConst(packed), packed_ofs, raw, num);
}

+// NOTE: the following are the recommended way to iterate over arrays of
+// potentially compressed elements, including remainder handling. Prefer them
+// over calling `Decompress2` directly, which does not handle remainders.
+// `DecompressAndCall` is for algorithms expressed as `Kernel` objects, such as
+// `Dot`. `Decompress*AndCompress*` are for varying numbers of input arrays and
+// user code expressed as lambdas.
+
// Invokes `kernel` for the `v.num` elements of `w` and `v`. Decompresses from
// both into groups of four vectors with lane type `Kernel::Raw`, passes them to
// `kernel.Update4`; loads the final vector(s) with zero-padding, then passes
@@ -733,8 +777,8 @@ HWY_INLINE float DecompressAndCall(D, const PackedSpan<const VT> v,
                 comp3);
}

-// Similar to `hn::Transform*`, but for compressed `T`. Used by ops-inl.h.
-// `DF` is the decompressed type, typically `float`.
+// Similar to `hn::Transform*`, but for compressed `T`. Used by `ops-inl.h`.
+// `DF` is the decompressed type, typically `float`. Calls `func(df, v_inout)`.
template <class DF, typename T, class Func>
HWY_INLINE void DecompressAndCompressInplace(DF df, T* HWY_RESTRICT inout,
                                             size_t num, Func&& func) {
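As a usage illustration of the in-place helper whose signature appears above, here is a minimal sketch (hypothetical, not part of this commit). It assumes it is compiled inside the usual `HWY_NAMESPACE` block of a gemma.cpp translation unit, with `BF16` from util/basics.h, and follows the lambda style used by the tests later in this diff: the lambda receives the tag and the decompressed vector and returns the replacement vector, which is recompressed back into `inout`.

```cpp
// Hypothetical sketch: scale a possibly-compressed BF16 array in place.
void ScaleInplace(BF16* HWY_RESTRICT inout, size_t num, float scale) {
  const hn::ScalableTag<float> df;  // DF: the decompressed type is float.
  using VF = hn::Vec<decltype(df)>;
  DecompressAndCompressInplace(df, inout, num,
                               [scale](decltype(df) d, VF v) HWY_ATTR -> VF {
                                 return hn::Mul(v, hn::Set(d, scale));
                               });
}
```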
@@ -773,6 +817,7 @@ HWY_INLINE void DecompressAndCompressInplace(DF df, T* HWY_RESTRICT inout,
}

// One extra argument. `DF` is the decompressed type, typically `float`.
+// Calls `func(df, v_inout, v1)`.
template <class DF, typename T, typename T1, class Func>
HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout,
                                              size_t num,
@@ -821,8 +866,64 @@ HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout,
  }
}

+// Two extra arguments. `DF` is the decompressed type, typically `float`.
+// Calls `func(df, v_inout, v1, v2)`.
+template <class DF, typename T, typename T1, typename T2, class Func>
+HWY_INLINE void Decompress2AndCompressInplace(
+    DF df, T* HWY_RESTRICT inout, size_t num, const T1* HWY_RESTRICT p1,
+    const T2* HWY_RESTRICT p2, const size_t p2_ofs, Func&& func) {
+  const auto packed_inout = MakeSpan(inout, num);
+  const auto packed1 = MakeSpan(p1, num);
+  const auto packed2 = MakeSpan(p2, p2_ofs + num);
+
+  using VF = hn::Vec<decltype(df)>;
+  HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df);
+  size_t i = 0;
+  if (num >= 2 * NF) {
+    for (; i <= num - 2 * NF; i += 2 * NF) {
+      VF v0, v1;
+      Decompress2(df, packed_inout, i, v0, v1);
+      VF v10, v11;
+      Decompress2(df, packed1, i, v10, v11);
+      VF v20, v21;
+      Decompress2(df, packed2, p2_ofs + i, v20, v21);
+      const VF out0 = func(df, v0, v10, v20);
+      const VF out1 = func(df, v1, v11, v21);
+      Compress2(df, out0, out1, packed_inout, i);
+    }
+  }
+
+  const size_t remaining = num - i;
+  HWY_DASSERT(remaining < 2 * NF);
+  if (HWY_UNLIKELY(remaining != 0)) {
+    HWY_ALIGN float buf_inout[2 * hn::MaxLanes(df)];
+    HWY_ALIGN float buf1[2 * hn::MaxLanes(df)];
+    HWY_ALIGN float buf2[2 * hn::MaxLanes(df)];
+    // Ensure the second vector is zeroed even if remaining <= NF.
+    hn::Store(hn::Zero(df), df, buf_inout + NF);
+    hn::Store(hn::Zero(df), df, buf1 + NF);
+    hn::Store(hn::Zero(df), df, buf2 + NF);
+    DecompressAndZeroPad(df, packed_inout, i, buf_inout, remaining);
+    DecompressAndZeroPad(df, packed1, i, buf1, remaining);
+    DecompressAndZeroPad(df, packed2, p2_ofs + i, buf2, remaining);
+    const VF v0 = hn::Load(df, buf_inout);
+    const VF v1 = hn::Load(df, buf_inout + NF);
+    const VF v10 = hn::Load(df, buf1);
+    const VF v11 = hn::Load(df, buf1 + NF);
+    const VF v20 = hn::Load(df, buf2);
+    const VF v21 = hn::Load(df, buf2 + NF);
+    const VF out0 = func(df, v0, v10, v20);
+    const VF out1 = func(df, v1, v11, v21);
+    Compress2(df, out0, out1, MakeSpan(buf_inout, 2 * NF), 0);
+    // Clang generates incorrect code for CopyBytes if num = 2.
+    for (size_t j = 0; j < remaining; ++j) {
+      inout[i + j] = hwy::ConvertScalarTo<T>(buf_inout[j]);
+    }
+  }
+}
+
// Single input, separate output. `DF` is the decompressed type, typically
-// `float`.
+// `float`. Calls `func(df, v1)`.
template <class DF, typename T, typename T1, class Func>
HWY_INLINE void Decompress1AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
                                         const T1* HWY_RESTRICT p1,
@@ -863,7 +964,8 @@ HWY_INLINE void Decompress1AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
  }
}

-// Two inputs. `DF` is the decompressed type, typically `float`.
+// Two inputs, separate output. `DF` is the decompressed type, typically
+// `float`. Calls `func(df, v1, v2)`.
template <class DF, typename T, typename T1, typename T2, class Func>
HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
                                         const T1* HWY_RESTRICT p1,
@@ -912,7 +1014,8 @@ HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
  }
}

-// Three inputs. `DF` is the decompressed type, typically `float`.
+// Three inputs, separate output. `DF` is the decompressed type, typically
+// `float`. Calls `func(df, v1, v2, v3)`.
template <class DF, typename T, typename T1, typename T2, typename T3,
          class Func>
HWY_INLINE void Decompress3AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,

compression/compress_test.cc
@@ -259,6 +259,13 @@ class TestDecompressAndCompress {
        [](DF, VF v, VF v1) HWY_ATTR -> VF { return hn::Add(v, v1); });
    HWY_ASSERT_ARRAY_EQ(expected2.get(), out.get(), num);

+    // `out` already contains v + v1.
+    Decompress2AndCompressInplace(
+        df, out.get(), num, p1.get(), p2.get(), /*p2_ofs=*/0,
+        [](DF, VF v, VF /*v1*/, VF v2)
+            HWY_ATTR -> VF { return hn::Add(v, v2); });
+    HWY_ASSERT_ARRAY_EQ(expected3.get(), out.get(), num);
+
    Decompress1AndCompressTo(df, out.get(), num, p.get(),
                             [](DF, VF v) HWY_ATTR -> VF { return v; });
    HWY_ASSERT_ARRAY_EQ(expected1.get(), out.get(), num);
@@ -480,9 +480,12 @@ class NibbleCodec {
    static_assert(kHalf <= 1);
    const size_t N = hn::Lanes(d8);
    constexpr size_t kMaxN = hn::MaxLanes(d8);
+    constexpr bool kPermuteAcrossBlocks =
+        HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86;
    // For kHalf=1 and 512-bit vectors, kAdd would be 16, which is out of
    // bounds for TableLookupBytes. We instead BroadcastBlock<1> there.
-    constexpr uint8_t kAdd = kMaxN < 64 ? kHalf * kMaxN / 4 : 0;
+    constexpr uint8_t kAdd =
+        kMaxN < 64 || kPermuteAcrossBlocks ? kHalf * kMaxN / 4 : 0;
    // The only performance-portable op to replicate bytes is TableLookupBytes,
    // but this only works if vectors are 128-bit or we first BroadcastBlock,
    // which only works for <= 512-bit vectors. For scalable vectors, we
@@ -506,7 +509,7 @@ class NibbleCodec {
    } else if constexpr (kMaxN <= 16) {  // <= 128-bit
      // No BroadcastBlock, we anyway only have one block.
      return hn::TableLookupBytes(bytes, hn::Load(d8, kRep4));
-    } else if constexpr (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) {
+    } else if constexpr (kPermuteAcrossBlocks) {
      // No BroadcastBlock, can directly permute across blocks.
      return hn::TableLookupLanes(bytes, hn::SetTableIndices(d8, kRep4));
    } else {  // 256..512-bit, no efficient TableLookupLanes
@@ -26,7 +26,6 @@ cc_library(
        "//io",
        "//io:blob_store",
        "@highway//:hwy",
        "@highway//:thread_pool",
    ],
)

compression/sfp_test.cc
@@ -37,37 +37,23 @@
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
-#include "compression/sfp-inl.h"
+#include "compression/compress-inl.h"
#include "hwy/tests/test_util-inl.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {

-// Decode
-float F32FromSFP8(uint32_t sfp) {
-  HWY_ASSERT(sfp < 256);
-  HWY_ASSERT(sfp != 0x80);  // -0 is reserved
+HWY_INLINE_VAR constexpr bool kPrint = false;

-  const uint32_t sign32 = (sfp & 0x80) << 24;
-  sfp &= 0x7F;
-  const bool large_e = sfp >= 64;
-  const size_t m_bits = large_e ? 3 : 2;
-  uint32_t m = sfp & ((1u << m_bits) - 1u);
-  size_t e = sfp >> m_bits;
-  if (sfp == 0) return 0.0f;
-  const uint32_t e_bias = large_e ? 15 : 23;
-  const uint32_t exp32 = static_cast<uint32_t>(127 + e - e_bias) << 23;
-  const uint32_t mnt32 = m << (23 - m_bits);
-  const uint32_t binary32 = sign32 | exp32 | mnt32;
-  float result;
-  hwy::CopySameSize(&binary32, &result);
-  return result;
+static float F32FromSFP8(uint32_t sfp) {
+  return CompressTraits<SfpStream>::ToFloatSlow(
+      SfpStream{static_cast<uint8_t>(sfp)});
}

// Used for HWY_AVX3_DL and newer.
void PrintTables() {
-  if (HWY_ONCE && false) {
+  if (HWY_ONCE && kPrint) {
    uint8_t hi[128];
    fprintf(stderr, "lo\n");
    for (uint32_t sfp = 0; sfp < 128; ++sfp) {
@@ -92,7 +78,7 @@ void TestAllUnique() {
    unique.insert(F32FromSFP8(sfp));
  }
  HWY_ASSERT_EQ(size_t{255}, unique.size());
-  if (false) {
+  if (kPrint) {
    for (float f : unique) {
      fprintf(stderr, "%e\n", f);
    }
@@ -163,7 +149,7 @@ HWY_INLINE uint32_t SFP8FromF32(float f) {
    if (m == 0) m = 1;
  }

-  if (false) {
+  if (kPrint) {
    fprintf(stderr, "in %x round %x rounded %x e %d m %x large_e %d\n",
            org_binary32, round, rounded, e, m, large_e);
  }
@@ -105,7 +105,7 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents, MatPadding padding,
  MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked);
  MatStorageT<MatT> compressed("mat", extents, ctx.allocator, padding);
  const float scale = SfpStream::kMax / extents.Area();
-  ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
+  ParallelFor(Parallelism::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
              Callers::kTest, [&](size_t r, size_t thread) {
                float* HWY_RESTRICT row = raw.Row(r);
                for (size_t c = 0; c < extents.cols; c++) {
@@ -134,7 +134,7 @@ MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
  MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked);
  MatStorageT<MatT> compressed("trans", extents, ctx.allocator, padding);
  const float scale = SfpStream::kMax / extents.Area();
-  ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
+  ParallelFor(Parallelism::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
              Callers::kTest, [&](size_t r, size_t thread) {
                float* HWY_RESTRICT row = raw.Row(r);
                for (size_t c = 0; c < extents.cols; c++) {
@@ -23,7 +23,9 @@ using json = nlohmann::json;

class BenchmarkArgs : public ArgsBase<BenchmarkArgs> {
 public:
-  BenchmarkArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  BenchmarkArgs(int argc, char* argv[], ConsumedArgs& consumed) {
+    InitAndParse(argc, argv, consumed);
+  }

  Path summarize_text;
  Path cross_entropy;
@@ -127,9 +129,16 @@ int BenchmarkTriviaQA(GemmaEnv& env, const Path& json_file,
}  // namespace gcpp

int main(int argc, char** argv) {
-  gcpp::GemmaEnv env(argc, argv);
-  gcpp::BenchmarkArgs benchmark_args(argc, argv);
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);
+  gcpp::BenchmarkArgs benchmark_args(argc, argv, consumed);
+  if (gcpp::HasHelp(argc, argv)) {
+    args.Help();
+    return 0;
+  }
+  consumed.AbortIfUnconsumed();
+
+  gcpp::GemmaEnv env(args);
  if (!benchmark_args.summarize_text.Empty()) {
    return BenchmarkSummary(env, benchmark_args.summarize_text);
  } else if (!benchmark_args.cross_entropy.Empty()) {
@@ -36,30 +36,29 @@

namespace gcpp {

-GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
-                   const InferenceArgs& inference)
-    : ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) {
+GemmaEnv::GemmaEnv(const GemmaArgs& args)
+    : initializer_value_(gcpp::InternalInit()),
+      ctx_(args.threading),
+      env_(ctx_),
+      gemma_(args, ctx_) {
  const ModelConfig& config = gemma_.Config();
  // Only allocate one for starters because GenerateBatch might not be called.
-  kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));
+  kv_caches_.push_back(KVCache(config, args.inference, ctx_.allocator));

-  if (inference.verbosity >= 2) {
-    ShowConfig(loader, threading, inference, config, gemma_.WeightReadMode(),
-               ctx_);
+  if (args.inference.verbosity >= 2) {
+    ShowConfig(args, config, gemma_.WeightReadMode(), ctx_);
  }
+  if (args.inference.verbosity >= 3) env_.print_best = true;
+  if (args.inference.verbosity >= 4) env_.print_config = true;

  runtime_config_ = {
-      .max_generated_tokens = inference.max_generated_tokens,
-      .temperature = inference.temperature,
-      .verbosity = inference.verbosity,
+      .max_generated_tokens = args.inference.max_generated_tokens,
+      .temperature = args.inference.temperature,
+      .verbosity = args.inference.verbosity,
  };
-  inference.CopyTo(runtime_config_);
+  args.inference.CopyTo(runtime_config_);
}

-GemmaEnv::GemmaEnv(int argc, char** argv)
-    : GemmaEnv(LoaderArgs(argc, argv), ThreadingArgs(argc, argv),
-               InferenceArgs(argc, argv)) {}
-
QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
  QueryResult result;

@@ -229,19 +228,19 @@ static constexpr const char* CompiledConfig() {
  }
}

-void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
-                const InferenceArgs& inference, const ModelConfig& config,
+void ShowConfig(const GemmaArgs& args, const ModelConfig& config,
                const WeightsPtrs::Mode weight_read_mode,
                const ThreadingContext& ctx) {
-  threading.Print(inference.verbosity);
-  loader.Print(inference.verbosity);
-  inference.Print(inference.verbosity);
-  fprintf(
-      stderr, "Model : %s, to_bf16 %d, mmap %d => %s\n",
-      config.Specifier().c_str(), static_cast<int>(loader.to_bf16),
-      static_cast<int>(loader.map), WeightsPtrs::ToString(weight_read_mode));
+  args.threading.Print(args.inference.verbosity);
+  args.loader.Print(args.inference.verbosity);
+  args.inference.Print(args.inference.verbosity);
+  fprintf(stderr,
+          "Model : %s, to_bf16 %d, mmap %d => %s\n",
+          config.Specifier().c_str(), static_cast<int>(args.loader.to_bf16),
+          static_cast<int>(args.loader.map),
+          WeightsPtrs::ToString(weight_read_mode));

-  if (inference.verbosity >= 2) {
+  if (args.inference.verbosity >= 2) {
    time_t now = time(nullptr);
    char* dt = ctime(&now);  // NOLINT
    char cpu100[100] = "unknown";

@@ -254,7 +253,7 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
            "Instruction set : %s (%zu bits)\n"
            "Compiled config : %s, profiler %d\n"
            "Memory MiB : %4zu\n",
-            dt, cpu100, static_cast<int>(threading.bind),
+            dt, cpu100, static_cast<int>(args.threading.bind),
            ctx.topology.TopologyString(), ctx.pools.PinString(),
            CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()),
            ctx.cache_info.VectorBytes() * 8, CompiledConfig(),
@@ -262,22 +261,4 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
  }
}

-void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
-              const InferenceArgs& inference) {
-  std::cerr
-      << "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
-         "==========================================================\n\n"
-         "To run with pre-2025 weights, specify --tokenizer and --weights.\n"
-         "With the single-file weights format, specify just --weights.\n";
-  std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
-               "--weights gemma2-2b-it-sfp.sbs\n";
-  std::cerr << "\n*Model Loading Arguments*\n\n";
-  loader.Help();
-  std::cerr << "\n*Threading Arguments*\n\n";
-  threading.Help();
-  std::cerr << "\n*Inference Arguments*\n\n";
-  inference.Help();
-  std::cerr << "\n";
-}
-
}  // namespace gcpp

evals/benchmark_helper.h
@@ -23,7 +23,7 @@

#include "gemma/configs.h"
#include "gemma/gemma.h"
-#include "gemma/gemma_args.h"
+#include "gemma/gemma_args.h"  // IWYU pragma: export
#include "gemma/tokenizer.h"  // WrapAndTokenize
#include "ops/matmul.h"
#include "util/threading_context.h"
@@ -50,10 +50,8 @@ struct QueryResultAndMetrics {
// Convenience class to load a model and run inference.
class GemmaEnv {
 public:
-  // Calls the other constructor with *Args arguments initialized from argv.
-  GemmaEnv(int argc, char** argv);
-  GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
-           const InferenceArgs& inference);
+  explicit GemmaEnv(const GemmaArgs& args);

  MatMulEnv& Env() { return env_; }

  size_t MaxGeneratedTokens() const {
@@ -125,6 +123,8 @@ class GemmaEnv {
  MatMulEnv& MutableEnv() { return env_; }

 private:
+  // This is used to ensure that InternalInit is called before anything else.
+  int initializer_value_ = 0;
  ThreadingContext ctx_;
  MatMulEnv env_;
  Gemma gemma_;
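The `initializer_value_` trick relies on the rule that non-static members are initialized in declaration order, so a dummy `int` declared before the heavyweight members forces the init function to run first. A standalone sketch of the idiom (hypothetical names, not from the repo):

```cpp
#include <cstdio>

static int Init() {  // stands in for gcpp::InternalInit()
  puts("Init runs first");
  return 0;
}

struct Heavy {  // stands in for ThreadingContext, MatMulEnv, Gemma
  Heavy() { puts("Heavy members are constructed after Init"); }
};

class Env {
 public:
  Env() : initializer_value_(Init()) {}

 private:
  int initializer_value_;  // declared first, therefore initialized first
  Heavy heavy_;
};

int main() { Env env; }
```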
@@ -135,12 +135,9 @@ class GemmaEnv {
// Logs the inference speed in tokens/sec.
void LogSpeedStats(double time_start, size_t total_tokens);

-void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
-                const InferenceArgs& inference, const ModelConfig& config,
+void ShowConfig(const GemmaArgs& args, const ModelConfig& config,
                WeightsPtrs::Mode weight_read_mode,
                const ThreadingContext& ctx);
-void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
-              const InferenceArgs& inference);

}  // namespace gcpp

@@ -98,7 +98,11 @@ BENCHMARK(BM_coding_prompt)
    ->UseRealTime();

int main(int argc, char** argv) {
-  gcpp::GemmaEnv env(argc, argv);
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);
+  consumed.AbortIfUnconsumed();
+
+  gcpp::GemmaEnv env(args);
  env.SetMaxGeneratedTokens(256);
  gcpp::s_env = &env;

@@ -31,7 +31,9 @@ namespace gcpp {

class PromptArgs : public ArgsBase<PromptArgs> {
 public:
-  PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  PromptArgs(int argc, char* argv[], ConsumedArgs& consumed) {
+    InitAndParse(argc, argv, consumed);
+  }

  Path layers_output;  // optional
  std::string prompt;
@@ -51,11 +53,15 @@ class PromptArgs : public ArgsBase<PromptArgs> {
};

int Run(int argc, char** argv) {
-  PromptArgs prompt_args(argc, argv);
+  ConsumedArgs consumed(argc, argv);
+  const GemmaArgs args(argc, argv, consumed);
+  const PromptArgs prompt_args(argc, argv, consumed);
  AbortIfInvalidArgs(prompt_args);
+  consumed.AbortIfUnconsumed();

  json json_output;
-  GemmaEnv env(argc, argv);
+  GemmaEnv env(args);

  env.MutableConfig().layers_output =
      prompt_args.layers_output.Empty()
          ? LayersOutputFunc()
@@ -48,7 +48,7 @@ class GemmaBatchBench : public ::testing::Test {
  }
};

-TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
+std::vector<std::string> GenerateInputs() {
  std::vector<std::string> prompts = {
      {"Describe dynamic programming."},
      {"Explain how electric cars work."},
@@ -122,33 +122,38 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
    inputs.push_back(prompts[qpos++]);
    if (qpos == prompts.size()) qpos = 0;
  }
-  s_env->SetMaxGeneratedTokens(24);
+  return inputs;
+}
+
+TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
+  s_env->SetMaxGeneratedTokens(12);
+  const std::vector<std::string> inputs = GenerateInputs();

+  // Run multiple times so that auto-tuning is closer to complete.
+  for (size_t rep = 0; rep < 4; ++rep) {
    std::vector<std::string> responses = BatchGemmaReply(inputs);
    for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size());
         ++i) {
-      fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
+      fprintf(stderr, "Rep %zu batch answer %zu '%s'\n\n", rep, i,
+              responses[i].c_str());
    }

    PROFILER_PRINT_RESULTS();
-
-    // Run again: prefill will be faster due to autotuning. Fewer decode steps
-    // because those are already fast.
-    s_env->SetMaxGeneratedTokens(2);
-    responses = BatchGemmaReply(inputs);
-
-    PROFILER_PRINT_RESULTS();
  }
}

}  // namespace
}  // namespace gcpp

int main(int argc, char** argv) {
  fprintf(stderr, "GemmaEnv setup..\n");
-  gcpp::GemmaEnv env(argc, argv);
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);
+  consumed.AbortIfUnconsumed();
+
+  gcpp::GemmaEnv env(args);
  gcpp::s_env = &env;

  testing::InitGoogleTest(&argc, argv);

  return RUN_ALL_TESTS();
}

evals/gemma_test.cc
@@ -22,7 +22,6 @@

#include "evals/benchmark_helper.h"
#include "gemma/configs.h"
-#include "io/io.h"
#include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h"

@@ -42,7 +41,11 @@ class GemmaTest : public ::testing::Test {
  // Requires argc/argv, hence do not use `SetUpTestSuite`.
  static void InitEnv(int argc, char** argv) {
    HWY_ASSERT(s_env == nullptr);  // Should only be called once.
-    s_env = new GemmaEnv(argc, argv);
+    ConsumedArgs consumed(argc, argv);
+    GemmaArgs args(argc, argv, consumed);
+    consumed.AbortIfUnconsumed();
+
+    s_env = new GemmaEnv(args);
    const gcpp::ModelConfig& config = s_env->GetGemma()->Config();
    fprintf(stderr, "Using %s\n", config.Specifier().c_str());
  }
@@ -130,7 +133,7 @@ TEST_F(GemmaTest, Multiturn) {
  // Note: we do not rewind any <end_of_turn> tokens here. If the model
  // produced one and WrapAndTokenize() inserts another one, it will just be
  // duplicated.
-  mutable_prompt = "Please repeat all prior statements.";
+  mutable_prompt = "Please repeat what I just told you.";
  tokens = WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(),
                           config.wrapping, abs_pos, mutable_prompt);

@@ -167,6 +170,9 @@ TEST_F(GemmaTest, CrossEntropySmall) {
    case gcpp::Model::GEMMA2_27B:
      EXPECT_NEAR(entropy, 1.30f, 0.02f);
      break;
+    case gcpp::Model::GEMMA3_270M:
+      EXPECT_NEAR(entropy, 1.41f, 0.02f);
+      break;
    default:
      FAIL() << "no entropy expectation for this model";
      break;
@@ -178,7 +184,6 @@ TEST_F(GemmaTest, CrossEntropySmall) {

int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
-  gcpp::InternalInit();
  gcpp::GemmaTest::InitEnv(argc, argv);
  int ret = RUN_ALL_TESTS();
  gcpp::GemmaTest::DeleteEnv();
|
|||
namespace gcpp {
|
||||
|
||||
struct JsonArgs : public ArgsBase<JsonArgs> {
|
||||
JsonArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
JsonArgs(int argc, char* argv[], ConsumedArgs& consumed) {
|
||||
InitAndParse(argc, argv, consumed);
|
||||
}
|
||||
|
||||
Path input;
|
||||
|
||||
|
|
@@ -151,10 +153,14 @@ void Run(GemmaEnv& env, JsonArgs& json) {
int main(int argc, char** argv) {
  {
    PROFILER_ZONE("Startup.all");
-    gcpp::GemmaEnv env(argc, argv);
-    gcpp::JsonArgs json(argc, argv);
-    gcpp::AbortIfInvalidArgs(json);
-    gcpp::Run(env, json);
+    gcpp::ConsumedArgs consumed(argc, argv);
+    gcpp::GemmaArgs args(argc, argv, consumed);
+    gcpp::JsonArgs json_args(argc, argv, consumed);
+    gcpp::AbortIfInvalidArgs(json_args);
+    consumed.AbortIfUnconsumed();
+
+    gcpp::GemmaEnv env(args);
+    gcpp::Run(env, json_args);
  }
  PROFILER_PRINT_RESULTS();  // Must call outside the zone above.
  return 0;
@@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

include(FetchContent)
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 3b680cde3a556bead9cc23c8f595d07a44d5a0d5)
FetchContent_MakeAvailable(highway)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9)
FetchContent_MakeAvailable(sentencepiece)
@@ -24,20 +24,20 @@
#include <vector>

#include "gemma/gemma.h"
-#include "gemma/gemma_args.h"  // LoaderArgs
+#include "gemma/gemma_args.h"  // GemmaArgs
#include "gemma/tokenizer.h"
#include "util/args.h"
#include "util/threading_context.h"
#include "hwy/base.h"

int main(int argc, char** argv) {
-  gcpp::LoaderArgs loader(argc, argv);
-  gcpp::ThreadingArgs threading(argc, argv);
-  gcpp::InferenceArgs inference(argc, argv);
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);
  if (gcpp::HasHelp(argc, argv)) {
-    loader.Help();
+    args.Help();
    return 0;
  }
+  consumed.AbortIfUnconsumed();

  // Demonstrate constrained decoding by never outputting certain tokens.
  std::set<int> reject_tokens;
@@ -49,10 +49,10 @@ int main(int argc, char** argv) {
  }

  // Instantiate model and KV Cache
-  gcpp::ThreadingContext ctx(threading);
+  gcpp::ThreadingContext ctx(args.threading);
  gcpp::MatMulEnv env(ctx);
-  gcpp::Gemma gemma(loader, inference, ctx);
-  gcpp::KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
+  gcpp::Gemma gemma(args, ctx);
+  gcpp::KVCache kv_cache(gemma.Config(), args.inference, ctx.allocator);
  size_t generated = 0;

  // Tokenize instructions.
@@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

include(FetchContent)
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 3b680cde3a556bead9cc23c8f595d07a44d5a0d5)
FetchContent_MakeAvailable(highway)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
FetchContent_MakeAvailable(sentencepiece)
@@ -23,7 +23,7 @@
#include <vector>

#include "third_party/gemma_cpp/gemma/gemma.h"
-#include "third_party/gemma_cpp/gemma/gemma_args.h"  // LoaderArgs
+#include "third_party/gemma_cpp/gemma/gemma_args.h"  // GemmaArgs
#include "third_party/gemma_cpp/gemma/tokenizer.h"
#include "third_party/gemma_cpp/ops/matmul.h"
#include "third_party/gemma_cpp/util/threading_context.h"
@@ -31,18 +31,11 @@

class SimplifiedGemma {
 public:
-  SimplifiedGemma(const gcpp::LoaderArgs& loader,
-                  const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
-                  const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
-      : ctx_(threading),
+  SimplifiedGemma(const gcpp::GemmaArgs& args)
+      : ctx_(args.threading),
        env_(ctx_),
-        gemma_(loader, inference, ctx_),
-        kv_cache_(gemma_.Config(), inference, ctx_.allocator) {}
-
-  SimplifiedGemma(int argc, char** argv)
-      : SimplifiedGemma(gcpp::LoaderArgs(argc, argv),
-                        gcpp::ThreadingArgs(argc, argv),
-                        gcpp::InferenceArgs(argc, argv)) {}
+        gemma_(args, ctx_),
+        kv_cache_(gemma_.Config(), args.inference, ctx_.allocator) {}

  void Generate(std::string& prompt, size_t max_generated_tokens = 1024,
                float temperature = 0.7,
@@ -18,26 +18,16 @@
#include <string>

#include "third_party/gemma_cpp/examples/simplified_gemma/gemma.hpp"
-#include "gemma/gemma_args.h"  // LoaderArgs
+#include "gemma/gemma_args.h"

int main(int argc, char** argv) {
-  // Standard usage: LoaderArgs takes argc and argv as input, then parses
-  // necessary flags.
-  gcpp::LoaderArgs loader(argc, argv);
+  // Sets arguments from argc and argv. Note that you can instead pass in
+  // LoaderArgs, ThreadingArgs, and InferenceArgs directly.
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);
+  consumed.AbortIfUnconsumed();

-  // Optional: LoaderArgs can also take tokenizer and weights paths directly.
-  //
-  // gcpp::LoaderArgs loader("/path/to/tokenizer", "/path/to/weights",
-  //                         "model_identifier");
-
-  // Optional: ThreadingArgs and InferenceArgs can be passed in as well. If not
-  // specified, default values will be used.
-  //
-  // gcpp::InferenceArgs inference(argc, argv);
-  // gcpp::ThreadingArgs threading(argc, argv);
-  // SimplifiedGemma gemma(loader, threading, inference);
-
-  SimplifiedGemma gemma(loader);
+  SimplifiedGemma gemma(args);
  std::string prompt = "Write a greeting to the world.";
  gemma.Generate(prompt, 256, 0.6);

gemma/activations.h
@@ -24,6 +24,9 @@
#include <vector>

#include "gemma/configs.h"  // ModelConfig
+#include "gemma/gemma_args.h"  // AttentionImpl
+#include "gemma/kv_cache.h"
+#include "gemma/tensor_stats.h"
#include "ops/ops.h"  // CreateInvTimescale
#include "util/basics.h"  // BF16
#include "util/mat.h"  // MatStorageT
@@ -31,7 +34,6 @@

namespace gcpp {

-struct AttentionActivations {
// Returns the scale value to use for the query in the attention computation.
// Also called by ops_test.
static inline float ChooseQueryScale(const ModelConfig& config) {
@@ -43,24 +45,32 @@ struct AttentionActivations {
  return 1.0f / sqrtf(static_cast<float>(layer_config.qkv_dim));
}

+struct AttentionActivations {
  AttentionActivations(
      const ModelConfig& config, const LayerConfig& layer_config,
-      size_t batch_size, size_t seq_len, const Allocator& allocator,
+      size_t batch_size, size_t seq_len, const RuntimeConfig& runtime_config,
+      const Allocator& allocator,
      std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
-      : config(config),
-
-        // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
-        // and does not use an external KV cache.
+      : // `vocab_size == 0` means it is for Vit part, VitAttention is still
+        // MHA and does not use an external KV cache.
        q(MatFactory("q", batch_size,
                     config.vocab_size == 0
                         ? layer_config.heads * 3 * layer_config.qkv_dim
                         : layer_config.heads * layer_config.qkv_dim,
                     allocator)),
+        q_bf(MatFactory("q_bf", batch_size,
+                        config.vocab_size == 0
+                            ? layer_config.heads * 3 * layer_config.qkv_dim
+                            : layer_config.heads * layer_config.qkv_dim,
+                        allocator)),
        q_T(MatFactory("q_T", layer_config.qkv_dim,
                       config.vocab_size == 0
                           ? batch_size * layer_config.heads * 3
                           : batch_size * layer_config.heads,
                       allocator)),
        vit_Q(MatFactory("Q2", batch_size, layer_config.qkv_dim, allocator)),
        vit_K(MatFactory("K2", seq_len, layer_config.qkv_dim, allocator)),
        vit_C(MatFactory("C2", batch_size, seq_len, allocator)),
        pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
                                   config.model_dim, allocator)),
        att(MatFactory("att", batch_size, layer_config.heads * seq_len,
@@ -68,6 +78,10 @@ struct AttentionActivations {
        att_out(MatFactory("att_out", batch_size,
                           layer_config.heads * layer_config.qkv_dim,
                           allocator)),
+        softmax_max(MatFactory("softmax_max", batch_size, layer_config.heads,
+                               allocator)),
+        softmax_d(
+            MatFactory("softmax_d", batch_size, layer_config.heads, allocator)),
        att_sums(
            MatFactory("att_sums", batch_size, config.model_dim, allocator)),

@@ -76,11 +90,7 @@ struct AttentionActivations {
                           layer_config.post_qk == PostQKType::HalfRope)),
        inv_timescale_global(CreateInvTimescale(
            allocator, layer_config.qkv_dim,
-            layer_config.post_qk == PostQKType::HalfRope, 1000000.0)),
-
-        div_seq_len(static_cast<uint32_t>(seq_len)),
-        div_heads(static_cast<uint32_t>(layer_config.heads)),
-        query_scale(ChooseQueryScale(config)) {
+            layer_config.post_qk == PostQKType::HalfRope, 1000000.0)) {
    // Batch size can be 0 in experimental code so do not assert.
    if (batch_size == 0) {
      static std::atomic_flag warned = ATOMIC_FLAG_INIT;
@@ -94,44 +104,153 @@ struct AttentionActivations {
     // If we forget any MatMul outputs here, debug builds print a warning but
     // fill them in each MatMul call.
     q.AllocateAndAttachRowPtrs(row_ptrs);
+    q_bf.AllocateAndAttachRowPtrs(row_ptrs);
     q_T.AllocateAndAttachRowPtrs(row_ptrs);
     vit_C.AllocateAndAttachRowPtrs(row_ptrs);
     att_sums.AllocateAndAttachRowPtrs(row_ptrs);
   }

   void SetBatchSize(size_t batch_size) {
     q.OverrideRows(batch_size);
+    q_bf.OverrideRows(batch_size);
     // q_T rows are always qkv_dim!

     vit_Q.OverrideRows(batch_size);
     // vit_K stays seq_len!
     vit_C.OverrideRows(batch_size);

     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
+    softmax_max.OverrideRows(batch_size);
+    softmax_d.OverrideRows(batch_size);
     att_sums.OverrideRows(batch_size);

     // `inv_timescale*` are not batched.
   }

-  const ModelConfig& config;
-
   MatStorageT<float> q;  // query
-  MatStorageT<float> q_T;  // Transposed to maximize attention speed.
+  MatStorageT<BF16> q_bf;
+  MatStorageT<BF16> q_T;  // Transposed to maximize attention speed.

   MatStorageT<float> vit_Q;
   MatStorageT<float> vit_K;
   MatStorageT<float> vit_C;

   MatStorageT<float> pre_att_rms_out;
   MatStorageT<float> att;      // attention vector
   MatStorageT<float> att_out;  // attention output
+  MatStorageT<float> softmax_max;  // see OnlineSoftmaxState
+  MatStorageT<float> softmax_d;    // see OnlineSoftmaxState
   // Accumulation of attention outputs over heads
   MatStorageT<BF16> att_sums;

   // Rope
   MatStorageT<float> inv_timescale;
   MatStorageT<float> inv_timescale_global;
 };

+// A non-owning view of AttentionActivations.
+struct AttentionActivationsPtrs {
+  AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len)
+      : config(config),
+        div_seq_len(static_cast<uint32_t>(seq_len)),
+        div_heads(static_cast<uint32_t>(config.layer_configs[0].heads)),
+        query_scale(ChooseQueryScale(config)) {}
+
+  AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len,
+                           const AttentionActivations& activations)
+      : AttentionActivationsPtrs(config, seq_len) {
+    q = activations.q;
+    q_bf = activations.q_bf;
+    q_T = activations.q_T;
+    vit_Q = activations.vit_Q;
+    vit_K = activations.vit_K;
+    vit_C = activations.vit_C;
+    pre_att_rms_out = activations.pre_att_rms_out;
+    att = activations.att;
+    att_out = activations.att_out;
+    softmax_max = activations.softmax_max;
+    softmax_d = activations.softmax_d;
+    att_sums = activations.att_sums;
+    inv_timescale = activations.inv_timescale;
+    inv_timescale_global = activations.inv_timescale_global;
+  }
+
+  void SetBatchSize(size_t batch_size) {
+    q.OverrideRows(batch_size);
+    q_bf.OverrideRows(batch_size);
+    // q_T rows are always qkv_dim!
+
+    vit_Q.OverrideRows(batch_size);
+    // vit_K stays seq_len!
+    vit_C.OverrideRows(batch_size);
+
+    pre_att_rms_out.OverrideRows(batch_size);
+    att.OverrideRows(batch_size);
+    att_out.OverrideRows(batch_size);
+    softmax_max.OverrideRows(batch_size);
+    softmax_d.OverrideRows(batch_size);
+    att_sums.OverrideRows(batch_size);
+    // `inv_timescale*` are not batched.
+  }
+
+  size_t SeqLen() const {
+    return static_cast<size_t>(div_seq_len.GetDivisor());
+  }
+
+  const ModelConfig& config;
+
+  // For the matrices below, the batch_size dimension is really qbatch.Size() *
+  // token_batch_size, but in all known uses, one of those is 1. Specifically,
+  // during PrefillTBatch, it is prompt length (up to some max batch size)
+  // and otherwise it's qbatch.Size().
+
+  // Query matrix of size batch_size x (q_heads * qkv_dim).
+  MatPtrT<float> q;
+  // Query matrix of size batch_size x (q_heads * qkv_dim).
+  MatPtrT<BF16> q_bf;
+  // Transposed query matrix for faster Q*K^T.
+  MatPtrT<BF16> q_T;
+
+  MatPtrT<float> vit_Q;
+  MatPtrT<float> vit_K;
+  MatPtrT<float> vit_C;
+
+  // Output of RMSNorm before attention, size batch_size x model_dim.
+  MatPtrT<float> pre_att_rms_out;
+  // Attention scores computed from Q*K^T, size batch_size x (q_heads *
+  // seq_len).
+  MatPtrT<float> att;
+  // Attention output computed from att * V, size batch_size x (q_heads *
+  // qkv_dim).
+  MatPtrT<float> att_out;
+  // The maximum logit value encountered when computing att_out from att,
+  // size batch_size x q_heads. See OnlineSoftmaxState for details.
+  // WARNING: Only filled in for AttentionImpl::kOld.
+  MatPtrT<float> softmax_max;
+  // The sum of scaled exponentials when computing att_out from att,
+  // size batch_size x q_heads. See OnlineSoftmaxState for details.
+  // WARNING: Only filled in for AttentionImpl::kOld.
+  MatPtrT<float> softmax_d;
+  // Accumulation of attention outputs over heads, size batch_size x
+  // model_dim.
+  MatPtrT<BF16> att_sums;
+  // Inverse timescales for RoPE computation.
+  MatPtrT<float> inv_timescale;
+  // Inverse timescales for global RoPE computation.
+  MatPtrT<float> inv_timescale_global;
+  // Divisor for faster division by sequence length.
+  hwy::Divisor div_seq_len;
+  // Unfortunately, some models have had non-power-of-two heads.
+  // Divisor for faster division by number of heads.
+  hwy::Divisor div_heads;
+  // Query scaling factor for attention computation.
+  float query_scale;
+};

 struct Activations {
-  Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
-              ThreadingContext& ctx,
+  Activations(const RuntimeConfig& runtime_config, const ModelConfig& config,
+              size_t batch_size, size_t seq_len, ThreadingContext& ctx,
               std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : layer_config(config.layer_configs[0]),
@@ -150,8 +269,18 @@ struct Activations {
         ffw_out(
             MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)),
-        attention(config, layer_config, batch_size, seq_len, ctx.allocator,
-                  row_ptrs) {
+        max_workers(ctx.pools.MaxWorkers()),
+        s_ffw_in(config.num_layers, max_workers),
+        s_ffw_hidden(config.num_layers, max_workers),
+        s_ffw_out(config.num_layers, max_workers),
+
+        s_w_gating_einsum_w1(config.num_layers, max_workers),
+        s_w_gating_einsum_w2(config.num_layers, max_workers),
+        s_w_linear_w(config.num_layers, max_workers),
+        attention_impl(runtime_config.attention_impl),
+        attention_storage(config, layer_config, batch_size, seq_len,
+                          runtime_config, ctx.allocator, row_ptrs),
+        attention(config, seq_len, attention_storage) {
     HWY_ASSERT(batch_size != 0);

     // For MatMul outputs, precompute their row pointers.
@@ -167,6 +296,12 @@ struct Activations {
     // Note that BindC on any MatMul output considerably slows down Prefill.
   }

+  ~Activations() {
+    s_ffw_in.ReduceAndPrint("ffw_in");
+    s_ffw_hidden.ReduceAndPrint("ffw_hidden");
+    s_ffw_out.ReduceAndPrint("ffw_out");
+  }
+
   // Negligible CPU time.
   void SetBatchSize(size_t batch_size) {
     x.OverrideRows(batch_size);
@@ -179,6 +314,9 @@ struct Activations {
     C2.OverrideRows(batch_size);
     ffw_out.OverrideRows(batch_size);

+    attention_storage.SetBatchSize(batch_size);
+    // `AttentionActivationsPtrs` holds `MatPtrT` which also require updating;
+    // their row override is not updated when the underlying storage changes.
     attention.SetBatchSize(batch_size);
   }
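Note on the double SetBatchSize above: the view must be re-overridden because it caches its own row count. A minimal sketch of that pitfall, with hypothetical `Storage`/`View` names (not the gemma.cpp `MatStorageT`/`MatPtrT` API):

#include <cstddef>

// Owning buffer and a non-owning view that copies the shape at bind time.
struct Storage {
  float* data = nullptr;
  size_t rows = 0, cols = 0;
  void OverrideRows(size_t r) { rows = r; }  // does not touch any views
};

struct View {
  float* data = nullptr;
  size_t rows = 0, cols = 0;
  void Bind(const Storage& s) { data = s.data; rows = s.rows; cols = s.cols; }
  // Must be called again whenever the storage's row count changes, because
  // the view holds its own copy of `rows`.
  void OverrideRows(size_t r) { rows = r; }
};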
@@ -195,7 +333,19 @@ struct Activations {
   MatStorageT<BF16> C2;
   MatStorageT<float> ffw_out;

-  AttentionActivations attention;
+  const size_t max_workers;
+  TensorStats s_ffw_in;
+  TensorStats s_ffw_hidden;  // after Activation+gating
+  TensorStats s_ffw_out;
+
+  TensorStats s_w_gating_einsum_w1;
+  TensorStats s_w_gating_einsum_w2;
+  TensorStats s_w_linear_w;
+
+  AttentionImpl attention_impl;
+
+  AttentionActivations attention_storage;
+  AttentionActivationsPtrs attention;
 };

 }  // namespace gcpp
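The `softmax_max`/`softmax_d` pair above ("see OnlineSoftmaxState") is the standard online-softmax statistic: a running maximum and a sum of exponentials that is rescaled whenever the maximum grows. A self-contained sketch of the recurrence (not the gemma.cpp `Softmax` kernel):

#include <cmath>
#include <cstddef>

// Online softmax statistics: running max m and sum of scaled exponentials d.
// After consuming all logits x_0..x_{n-1}, softmax(x_i) = exp(x_i - m) / d.
struct OnlineSoftmax {
  float m = -INFINITY;  // maximum logit seen so far
  float d = 0.0f;       // sum of exp(x_j - m) over logits seen so far
  void Update(float x) {
    const float new_m = x > m ? x : m;
    d = d * std::exp(m - new_m) + std::exp(x - new_m);  // rescale old sum
    m = new_m;
  }
};

// Example: normalizing attention logits in a single pass plus one rescale.
inline void SoftmaxInPlace(float* logits, size_t n) {
  OnlineSoftmax s;
  for (size_t i = 0; i < n; ++i) s.Update(logits[i]);
  for (size_t i = 0; i < n; ++i) logits[i] = std::exp(logits[i] - s.m) / s.d;
}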
@@ -15,18 +15,22 @@

 // Test client for API server

-#include <iostream>
-#include <string>
-#include <sstream>
 #include <stdio.h>

+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+#include <sstream>
+#include <string>

 #include "httplib.h"
-#include "nlohmann/json.hpp"
 #include "gemma/gemma_args.h"
+#include "nlohmann/json.hpp"

 using json = nlohmann::json;

 namespace gcpp {

 // ANSI color codes
 const std::string RESET = "\033[0m";
 const std::string BOLD = "\033[1m";
@@ -38,8 +42,14 @@ const std::string RED = "\033[31m";

 class APIClient {
  public:
-  APIClient(const std::string& host, int port, const std::string& api_key = "", const std::string& model = "gemma3-4b")
-      : host_(host), port_(port), api_key_(api_key), model_(model), use_https_(port == 443), interactive_mode_(false) {
+  APIClient(const std::string& host, int port, const std::string& api_key = "",
+            const std::string& model = "gemma3-4b")
+      : host_(host),
+        port_(port),
+        api_key_(api_key),
+        model_(model),
+        use_https_(port == 443),
+        interactive_mode_(false) {
     if (use_https_) {
       ssl_client_ = std::make_unique<httplib::SSLClient>(host, port);
       ssl_client_->set_read_timeout(60, 0);
@@ -58,7 +68,9 @@ public:

     std::string endpoint;
     if (is_public_api) {
-      endpoint = stream ? "/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse"
+      endpoint =
+          stream
+              ? "/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse"
               : "/v1beta/models/gemini-2.0-flash:generateContent";
     } else {
       endpoint = stream ? "/v1beta/models/" + model_ + ":streamGenerateContent"
@@ -67,7 +79,8 @@ public:

     // Only show verbose output in non-interactive mode
     if (!interactive_mode_) {
-      std::cout << "\n" << BOLD << BLUE << "📤 POST " << endpoint << RESET << std::endl;
+      std::cout << "\n"
+                << BOLD << BLUE << "📤 POST " << endpoint << RESET << std::endl;
       std::cout << "Request: " << request.dump(2) << std::endl;
     }
@@ -83,18 +96,21 @@ public:
     json response = ProcessRequest(request, stream);

     if (response.contains("error")) {
-      std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl;
+      std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET
+                << std::endl;
     }
   }

   void TestListModels() {
-    std::cout << "\n" << BOLD << BLUE << "📤 GET /v1beta/models" << RESET << std::endl;
+    std::cout << "\n"
+              << BOLD << BLUE << "📤 GET /v1beta/models" << RESET << std::endl;

     httplib::Headers headers;
     if (!api_key_.empty()) {
       headers.emplace("X-goog-api-key", api_key_);
     }
-    auto res = use_https_ ? ssl_client_->Get("/v1beta/models", headers) : client_->Get("/v1beta/models", headers);
+    auto res = use_https_ ? ssl_client_->Get("/v1beta/models", headers)
+                          : client_->Get("/v1beta/models", headers);

     if (res && res->status == 200) {
       json response = json::parse(res->body);
@@ -106,7 +122,9 @@ public:
   }

   void InteractiveChat() {
-    std::cout << "\n" << BOLD << CYAN << "💬 Interactive Chat Mode (with session)" << RESET << std::endl;
+    std::cout << "\n"
+              << BOLD << CYAN << "💬 Interactive Chat Mode (with session)"
+              << RESET << std::endl;
     std::cout << "Type ':gemma %q' to end.\n" << std::endl;

     interactive_mode_ = true;

@@ -141,14 +159,16 @@ public:

     if (response.contains("candidates") && !response["candidates"].empty()) {
       auto& candidate = response["candidates"][0];
-      if (candidate.contains("content") && candidate["content"].contains("parts")) {
+      if (candidate.contains("content") &&
+          candidate["content"].contains("parts")) {
         for (const auto& part : candidate["content"]["parts"]) {
           if (part.contains("text")) {
             std::string assistant_response = part["text"].get<std::string>();

             // For streaming, the response is already displayed in real-time.
             // Just add to message history for context.
-            json assistant_message = {{"parts", {{{"text", assistant_response}}}}};
+            json assistant_message = {
+                {"parts", {{{"text", assistant_response}}}}};
             if (!api_key_.empty()) {
               assistant_message["role"] = "model";
             }
@@ -157,7 +177,8 @@ public:
         }
       }
     } else if (response.contains("error")) {
-      std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl;
+      std::cerr << RED << "❌ Error: " << response["error"]["message"]
+                << RESET << std::endl;
     }

     std::cout << std::endl;
@@ -165,14 +186,11 @@ public:
   }

 private:
-  json CreateAPIRequest(const std::string& prompt, const json& messages = json::array()) {
+  json CreateAPIRequest(const std::string& prompt,
+                        const json& messages = json::array()) {
     json request = {
-      {"generationConfig", {
-        {"temperature", 0.9},
-        {"topK", 1},
-        {"maxOutputTokens", 1024}
-      }}
-    };
+        {"generationConfig",
+         {{"temperature", 0.9}, {"topK", 1}, {"maxOutputTokens", 1024}}}};

     if (messages.empty()) {
       // Single prompt
@@ -189,38 +207,42 @@ private:
     return request;
   }

-  json ProcessNonStreamingRequest(const json& request, const std::string& endpoint) {
+  json ProcessNonStreamingRequest(const json& request,
+                                  const std::string& endpoint) {
     httplib::Headers headers = {{"Content-Type", "application/json"}};
     if (!api_key_.empty()) {
       headers.emplace("X-goog-api-key", api_key_);
     }

-    auto res = use_https_ ? ssl_client_->Post(endpoint, headers, request.dump(), "application/json")
-                          : client_->Post(endpoint, headers, request.dump(), "application/json");
+    auto res = use_https_ ? ssl_client_->Post(endpoint, headers, request.dump(),
+                                              "application/json")
+                          : client_->Post(endpoint, headers, request.dump(),
+                                          "application/json");

     if (res && res->status == 200) {
       json response = json::parse(res->body);
       if (!interactive_mode_) {
-        std::cout << "\n" << BOLD << GREEN << "📥 Response:" << RESET << std::endl;
+        std::cout << "\n"
+                  << BOLD << GREEN << "📥 Response:" << RESET << std::endl;
         std::cout << response.dump(2) << std::endl;
       }
       return response;
     } else {
-      json error_response = {
-        {"error", {
-          {"message", "Request failed"},
-          {"status", res ? res->status : -1}
-        }}
-      };
+      json error_response = {{"error",
+                              {{"message", "Request failed"},
+                               {"status", res ? res->status : -1}}}};
      if (res && !res->body.empty()) {
        error_response["error"]["details"] = res->body;
      }
-      std::cerr << RED << "❌ Request failed. Status: " << (res ? res->status : -1) << RESET << std::endl;
+      std::cerr << RED
+                << "❌ Request failed. Status: " << (res ? res->status : -1)
+                << RESET << std::endl;
       return error_response;
     }
   }

-  json ProcessStreamingRequest(const json& request, const std::string& endpoint) {
+  json ProcessStreamingRequest(const json& request,
+                               const std::string& endpoint) {
     std::string accumulated_response;

     // Use same SSE logic for both public and local APIs
@@ -233,7 +255,9 @@ private:
     }
     req.body = request.dump();

-    req.content_receiver = [&accumulated_response, this](const char* data, size_t data_length, uint64_t offset, uint64_t total_length) -> bool {
+    req.content_receiver = [&accumulated_response, this](
+                               const char* data, size_t data_length,
+                               uint64_t offset, uint64_t total_length) -> bool {
       std::string chunk(data, data_length);
       std::istringstream stream(chunk);
       std::string line;
@@ -244,14 +268,18 @@ private:

       if (event_data == "[DONE]") {
         if (!interactive_mode_) {
-          std::cout << "\n\n" << GREEN << "✅ Generation complete!" << RESET << std::endl;
+          std::cout << "\n\n"
+                    << GREEN << "✅ Generation complete!" << RESET
+                    << std::endl;
         }
       } else {
         try {
           json event = json::parse(event_data);
-          if (event.contains("candidates") && !event["candidates"].empty()) {
+          if (event.contains("candidates") &&
+              !event["candidates"].empty()) {
             auto& candidate = event["candidates"][0];
-            if (candidate.contains("content") && candidate["content"].contains("parts")) {
+            if (candidate.contains("content") &&
+                candidate["content"].contains("parts")) {
               for (const auto& part : candidate["content"]["parts"]) {
                 if (part.contains("text")) {
                   std::string text = part["text"].get<std::string>();
@@ -272,27 +300,22 @@ private:

     httplib::Response res;
     httplib::Error error;
-    bool success = use_https_ ? ssl_client_->send(req, res, error) : client_->send(req, res, error);
+    bool success = use_https_ ? ssl_client_->send(req, res, error)
+                              : client_->send(req, res, error);

     if (res.status == 200 && !accumulated_response.empty()) {
       return json{
-        {"candidates", {{
-          {"content", {
-            {"parts", {{{"text", accumulated_response}}}}
-          }}
-        }}}
-      };
+          {"candidates",
+           {{{"content", {{"parts", {{{"text", accumulated_response}}}}}}}}}};
     } else {
       json error_response = {
-        {"error", {
-          {"message", "Streaming request failed"},
-          {"status", res.status}
-        }}
-      };
+          {"error",
+           {{"message", "Streaming request failed"}, {"status", res.status}}}};
       if (!res.body.empty()) {
         error_response["error"]["details"] = res.body;
       }
-      std::cerr << RED << "❌ Streaming request failed. Status: " << res.status << RESET << std::endl;
+      std::cerr << RED << "❌ Streaming request failed. Status: " << res.status
+                << RESET << std::endl;
       return error_response;
     }
   }
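For reference, the `content_receiver` above splits each received chunk on newlines and strips the SSE `data: ` prefix before JSON-parsing the payload. A minimal standalone version of just that parsing loop (assumptions as in the client code: events are line-delimited and `[DONE]` terminates the stream):

#include <sstream>
#include <string>

// Appends SSE payloads ("data: {...}") from one received chunk to `out`.
// Returns false once the [DONE] sentinel is seen.
inline bool ConsumeSseChunk(const std::string& chunk, std::string& out) {
  std::istringstream stream(chunk);
  std::string line;
  while (std::getline(stream, line)) {
    if (line.rfind("data: ", 0) != 0) continue;  // ignore non-data lines
    const std::string event_data = line.substr(6);
    if (event_data == "[DONE]") return false;
    out += event_data;  // caller parses this as JSON
  }
  return true;
}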
@@ -308,19 +331,55 @@ private:
   bool interactive_mode_;
 };

 struct ClientArgs : public ArgsBase<ClientArgs> {
   ClientArgs(int argc, char* argv[], ConsumedArgs& consumed) {
     InitAndParse(argc, argv, consumed);
   }
   ClientArgs() { Init(); }

   std::string host;
   int port;
   std::string api_key;
   std::string model;
   std::string prompt;
   bool interactive;

   template <class Visitor>
   void ForEach(const Visitor& visitor) {
     visitor(host, "host", std::string("localhost"),
             "Server host (default: localhost)");
     visitor(port, "port", 8080, "Server port (default: 8080)");
     visitor(api_key, "api_key", std::string(""),
             "Use public API with key (changes host to "
             "generativelanguage.googleapis.com:443)");
     visitor(model, "model", std::string("gemma3-4b"),
             "Model name to use (default: gemma3-4b)");
     visitor(prompt, "prompt", std::string("Hello! How are you?"),
             "Prompt for generation (default: 'Hello! How are you?')");
     visitor(interactive, "interactive", false,
             "Start interactive chat mode (0 = no, 1 = yes)");
   }
 };

 }  // namespace gcpp

 int main(int argc, char* argv[]) {
-  gcpp::ClientArgs client_args(argc, argv);
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::ClientArgs client_args(argc, argv, consumed);

   if (gcpp::HasHelp(argc, argv)) {
-    std::cout << "\nAPI Client for gemma.cpp\n";
-    std::cout << "========================\n\n";
+    fprintf(stderr,
+            "\nAPI Client for gemma.cpp\n"
+            "========================\n\n");
     client_args.Help();
-    std::cout << std::endl;
-    std::cout << "Environment Variables:" << std::endl;
-    std::cout << "  GOOGLE_API_KEY : Automatically use public Google API if set" << std::endl;
+    fprintf(stderr,
+            "\n*Environment Variables:\n"
+            "  GOOGLE_API_KEY : Automatically use public Google API if set\n");
     return 0;
   }

+  consumed.AbortIfUnconsumed();
+
   // Check for GOOGLE_API_KEY environment variable
   const char* env_api_key = std::getenv("GOOGLE_API_KEY");
   if (env_api_key != nullptr && strlen(env_api_key) > 0) {
@@ -335,11 +394,12 @@ int main(int argc, char* argv[]) {
     client_args.port = 443;
   }

-  std::cout << BOLD << YELLOW << "🚀 Testing API Server at "
-            << client_args.host << ":" << client_args.port << RESET << std::endl;
+  std::cout << BOLD << YELLOW << "🚀 Testing API Server at " << client_args.host
+            << ":" << client_args.port << RESET << std::endl;

   try {
-    APIClient client(client_args.host, client_args.port, client_args.api_key, client_args.model);
+    APIClient client(client_args.host, client_args.port, client_args.api_key,
+                     client_args.model);

     if (client_args.interactive) {
       client.InteractiveChat();

@@ -347,11 +407,12 @@ int main(int argc, char* argv[]) {
       client.TestListModels();
       client.TestGenerateContent(client_args.prompt, true);
     }

   } catch (const std::exception& e) {
     std::cerr << RED << "❌ Error: " << e.what() << RESET << std::endl;
     std::cerr << "Make sure the API server is running:" << std::endl;
-    std::cerr << "  ./build/gemma_api_server --tokenizer <path> --weights <path>" << std::endl;
+    std::cerr
+        << "  ./build/gemma_api_server --tokenizer <path> --weights <path>"
+        << std::endl;
     return 1;
   }
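The `ForEach`/visitor pattern used by `ClientArgs` lets one declaration drive both parsing and `Help()`: each `visitor(...)` call names the member, its flag, default, and help text exactly once. A minimal sketch of one such visitor (`PrintHelp` is a hypothetical name, not the actual `ArgsBase` internals; a parsing visitor would walk the same `ForEach` and match argv against `name`):

#include <cstdio>

// Help-printing visitor: ignores the member and default, prints flag + text.
struct PrintHelp {
  template <typename T>
  void operator()(T& /*member*/, const char* name, const T& /*init*/,
                  const char* help) const {
    std::fprintf(stderr, "  --%-12s %s\n", name, help);
  }
};

// Usage with a ClientArgs-like struct: args.ForEach(PrintHelp());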
@@ -15,22 +15,19 @@

 // HTTP API server for gemma.cpp with SSE support

-#include <stdio.h>
 #include <signal.h>
+#include <stdio.h>

-#include <iostream>
-#include <memory>
-#include <random>
-#include <string>
-#include <string_view>
-#include <vector>
-#include <thread>
-#include <atomic>
-#include <chrono>
-#include <sstream>
-#include <iomanip>
+#include <iostream>
+#include <memory>
+#include <mutex>
+#include <sstream>
+#include <string>
+#include <thread>  // NOLINT
+#include <unordered_map>
+#include <vector>

 // HTTP server library
 #undef CPPHTTPLIB_OPENSSL_SUPPORT
@@ -38,16 +35,12 @@
 #include "httplib.h"

-// JSON library
-#include "nlohmann/json.hpp"
-
 #include "compression/types.h"
 #include "gemma/gemma.h"
 #include "gemma/gemma_args.h"
 #include "gemma/tokenizer.h"
 #include "ops/matmul.h"
-#include "util/args.h"
 #include "hwy/base.h"
 #include "hwy/profiler.h"
+#include "nlohmann/json.hpp"

 using json = nlohmann::json;
@@ -90,7 +83,8 @@ struct ServerState {
     std::lock_guard<std::mutex> lock(sessions_mutex);
     auto& session = sessions[session_id];
     if (!session.kv_cache) {
-      session.kv_cache = std::make_unique<KVCache>(gemma->Config(), InferenceArgs(), env->ctx.allocator);
+      session.kv_cache = std::make_unique<KVCache>(
+          gemma->Config(), InferenceArgs(), env->ctx.allocator);
     }
     session.last_access = std::chrono::steady_clock::now();
     return session;

@@ -107,7 +101,8 @@ std::string GenerateSessionId() {
   return ss.str();
 }

-// Wraps messages with start_of_turn markers - handles both with and without roles
+// Wraps messages with start_of_turn markers - handles both with and without
+// roles
 std::string WrapMessagesWithTurnMarkers(const json& contents) {
   std::string prompt;
@@ -121,12 +116,14 @@ std::string WrapMessagesWithTurnMarkers(const json& contents) {
       std::string text = part["text"];

       if (role == "user") {
-        prompt += "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
+        prompt +=
+            "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
       } else if (role == "model") {
         prompt += text + "\n";
       } else if (role.empty()) {
         // Local format without roles - for now, treat as user input
-        prompt += "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
+        prompt +=
+            "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
       }
     }
   }
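To make the wrapping concrete: for a user/model/user history, the function above produces a prompt of the following shape, ending with an open model turn for the next completion (message text is illustrative):

<start_of_turn>user
What is the capital of France?
<start_of_turn>model
Paris.
<start_of_turn>user
And of Italy?
<start_of_turn>model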
@@ -163,18 +160,15 @@ RuntimeConfig ParseGenerationConfig(const json& request) {
   return config;
 }

-// Unified response formatter - creates consistent format regardless of request type
-json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false) {
+// Unified response formatter - creates consistent format regardless of request
+// type
+json CreateAPIResponse(const std::string& text,
+                       bool is_streaming_chunk = false) {
   json response = {
-    {"candidates", {{
-      {"content", {
-        {"parts", {{{"text", text}}}},
-        {"role", "model"}
-      }},
-      {"index", 0}
-    }}},
-    {"promptFeedback", {{"safetyRatings", json::array()}}}
-  };
+      {"candidates",
+       {{{"content", {{"parts", {{{"text", text}}}}, {"role", "model"}}},
+         {"index", 0}}}},
+      {"promptFeedback", {{"safetyRatings", json::array()}}}};

   // Only add finishReason for non-streaming chunks
   if (!is_streaming_chunk) {
@@ -185,7 +179,9 @@ json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false)
 }

 // Handle generateContent endpoint (non-streaming)
-void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) {
+void HandleGenerateContentNonStreaming(ServerState& state,
+                                       const httplib::Request& req,
+                                       httplib::Response& res) {
   try {
     json request = json::parse(req.body);

@@ -199,7 +195,9 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
     prompt = WrapMessagesWithTurnMarkers(request["contents"]);
   } else {
     res.status = 400;
-    res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json");
+    res.set_content(
+        json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(),
+        "application/json");
     return;
   }
@@ -209,12 +207,7 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
     // Set up runtime config
     RuntimeConfig runtime_config = ParseGenerationConfig(request);

-    // Collect full response
-    std::string full_response;
-    runtime_config.stream_token = [&full_response](int token, float) {
-      // Skip EOS token
-      return true;
-    };
+    runtime_config.stream_token = [](int token, float) { return true; };

     // Tokenize prompt
     std::vector<int> tokens = WrapAndTokenize(

@@ -227,7 +220,8 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques

     // Temporarily redirect output to capture response
     std::stringstream output;
-    runtime_config.stream_token = [&output, &state, &session, &tokens](int token, float) {
+    runtime_config.stream_token = [&output, &state, &session, &tokens](
+                                      int token, float) {
       // Skip prompt tokens
       if (session.abs_pos < tokens.size()) {
        session.abs_pos++;
@@ -279,7 +273,9 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
 }

 // Handle streamGenerateContent endpoint with SSE
-void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) {
+void HandleGenerateContentStreaming(ServerState& state,
+                                    const httplib::Request& req,
+                                    httplib::Response& res) {
   try {
     json request = json::parse(req.body);

@@ -293,7 +289,9 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
     prompt = WrapMessagesWithTurnMarkers(request["contents"]);
   } else {
     res.status = 400;
-    res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json");
+    res.set_content(
+        json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(),
+        "application/json");
     return;
   }
@@ -305,8 +303,8 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&

   // Set up chunked content provider for SSE
   res.set_chunked_content_provider(
-      "text/event-stream",
-      [&state, request, prompt, session_id](size_t offset, httplib::DataSink& sink) {
+      "text/event-stream", [&state, request, prompt, session_id](
+                               size_t offset, httplib::DataSink& sink) {
        try {
          // Lock for inference
          std::lock_guard<std::mutex> lock(state.inference_mutex);

@@ -338,7 +336,8 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&

          // Decode token
          std::string token_text;
-         state.gemma->Tokenizer().Decode(std::vector<int>{token}, &token_text);
+         state.gemma->Tokenizer().Decode(std::vector<int>{token},
+                                         &token_text);
          accumulated_text += token_text;

          // Send SSE event using unified formatter
@@ -365,8 +364,7 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
          final_event["usageMetadata"] = {
              {"promptTokenCount", tokens.size()},
              {"candidatesTokenCount", session.abs_pos - tokens.size()},
-             {"totalTokenCount", session.abs_pos}
-         };
+             {"totalTokenCount", session.abs_pos}};

          std::string final_sse = "data: " + final_event.dump() + "\n\n";
          sink.write(final_sse.data(), final_sse.size());

@@ -377,16 +375,13 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
          // Ensure all data is sent
          sink.done();
          return false;  // End streaming
-
        } catch (const std::exception& e) {
          json error_event = {{"error", {{"message", e.what()}}}};
          std::string error_sse = "data: " + error_event.dump() + "\n\n";
          sink.write(error_sse.data(), error_sse.size());
          return false;
        }
-      }
-    );
-
+      });
   } catch (const json::exception& e) {
     res.status = 400;
     res.set_content(
@@ -398,20 +393,20 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
 }

 // Handle models list endpoint
-void HandleListModels(ServerState& state, const InferenceArgs& inference, const httplib::Request& req, httplib::Response& res) {
+void HandleListModels(ServerState& state, const InferenceArgs& inference,
+                      const httplib::Request& req, httplib::Response& res) {
   json response = {
-    {"models", {{
-      {"name", "models/" + inference.model},
+      {"models",
+       {{{"name", "models/" + inference.model},
         {"version", "001"},
         {"displayName", inference.model},
         {"description", inference.model + " model running locally"},
         {"inputTokenLimit", 8192},
         {"outputTokenLimit", 8192},
-        {"supportedGenerationMethods", json::array({"generateContent", "streamGenerateContent"})},
+        {"supportedGenerationMethods",
+         json::array({"generateContent", "streamGenerateContent"})},
         {"temperature", 1.0},
-        {"topK", 1}
-      }}}
-  };
+        {"topK", 1}}}}};

   res.set_content(response.dump(), "application/json");
 }
@@ -421,37 +416,43 @@ void HandleListModels(ServerState& state, const InferenceArgs& inference, const
 //   server_running = false;
 // }

-void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading,
-               const InferenceArgs& inference) {
+void RunServer(const GemmaArgs& args) {
   std::cerr << "Loading model..." << std::endl;

   // Initialize model
-  ThreadingContext ctx(threading);
+  ThreadingContext ctx(args.threading);
   MatMulEnv env(ctx);

   ServerState state;
-  state.gemma = std::make_unique<Gemma>(loader, inference, ctx);
+  state.gemma = std::make_unique<Gemma>(args, ctx);
   state.env = &env;
   state.ctx = &ctx;

+  const InferenceArgs& inference = args.inference;
+
   httplib::Server server;

   // Set up routes
-  server.Get("/", [&inference](const httplib::Request&, httplib::Response& res) {
-    res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" + inference.model + ":generateContent", "text/plain");
-  });
+  server.Get(
+      "/", [&inference](const httplib::Request&, httplib::Response& res) {
+        res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" +
+                            inference.model + ":generateContent",
+                        "text/plain");
+      });

   // API endpoints
-  server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req, httplib::Response& res) {
+  server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req,
+                                                    httplib::Response& res) {
     HandleListModels(state, inference, req, res);
   });

   std::string model_endpoint = "/v1beta/models/" + inference.model;
-  server.Post(model_endpoint + ":generateContent", [&state](const httplib::Request& req, httplib::Response& res) {
+  server.Post(model_endpoint + ":generateContent",
+              [&state](const httplib::Request& req, httplib::Response& res) {
                HandleGenerateContentNonStreaming(state, req, res);
              });

-  server.Post(model_endpoint + ":streamGenerateContent", [&state](const httplib::Request& req, httplib::Response& res) {
+  server.Post(model_endpoint + ":streamGenerateContent",
+              [&state](const httplib::Request& req, httplib::Response& res) {
                HandleGenerateContentStreaming(state, req, res);
              });
@@ -466,12 +467,15 @@ void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading,
   std::cerr << "Starting API server on port " << inference.port << std::endl;
   std::cerr << "Model loaded successfully" << std::endl;
   std::cerr << "Endpoints:" << std::endl;
-  std::cerr << "  POST /v1beta/models/" << inference.model << ":generateContent" << std::endl;
-  std::cerr << "  POST /v1beta/models/" << inference.model << ":streamGenerateContent (SSE)" << std::endl;
+  std::cerr << "  POST /v1beta/models/" << inference.model << ":generateContent"
+            << std::endl;
+  std::cerr << "  POST /v1beta/models/" << inference.model
+            << ":streamGenerateContent (SSE)" << std::endl;
   std::cerr << "  GET  /v1beta/models" << std::endl;

   if (!server.listen("0.0.0.0", inference.port)) {
-    std::cerr << "Failed to start server on port " << inference.port << std::endl;
+    std::cerr << "Failed to start server on port " << inference.port
+              << std::endl;
   }

   cleanup_thread.join();
@@ -482,35 +486,27 @@ void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading,
 int main(int argc, char** argv) {
   gcpp::InternalInit();

-  gcpp::LoaderArgs loader(argc, argv);
-  gcpp::ThreadingArgs threading(argc, argv);
-  gcpp::InferenceArgs inference(argc, argv);
+  gcpp::ConsumedArgs consumed(argc, argv);
+  gcpp::GemmaArgs args(argc, argv, consumed);

   if (gcpp::HasHelp(argc, argv)) {
-    std::cerr << "\n\nAPI server for gemma.cpp\n";
-    std::cout << "========================\n\n";
-    std::cerr << "Usage: " << argv[0] << " --weights <path> --tokenizer <path> [options]\n";
-    std::cerr << "\nOptions:\n";
-    std::cerr << "  --port PORT    Server port (default: 8080)\n";
-    std::cerr << "  --model MODEL  Model name for endpoints (default: gemma3-4b)\n";
-    std::cerr << "\n";
-    std::cerr << "\n*Model Loading Arguments*\n\n";
-    loader.Help();
-    std::cerr << "\n*Threading Arguments*\n\n";
-    threading.Help();
-    std::cerr << "\n*Inference Arguments*\n\n";
-    inference.Help();
-    std::cerr << "\n";
+    fprintf(
+        stderr,
+        "\n\nAPI server for gemma.cpp\n"
+        "========================\n\n"
+        "  --port PORT    Server port (default: 8080)\n"
+        "  --model MODEL  Model name for endpoints (default: gemma3-4b)\n");
+    args.Help();
     return 0;
   }

-  // Arguments are now handled by InferenceArgs
+  consumed.AbortIfUnconsumed();

   // // Set up signal handler
   // signal(SIGINT, gcpp::HandleShutdown);
   // signal(SIGTERM, gcpp::HandleShutdown);

-  gcpp::RunServer(loader, threading, inference);
+  gcpp::RunServer(args);

   return 0;
 }
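For reference, each chunk emitted by the streaming handler above is one SSE event whose payload is the `CreateAPIResponse` JSON, and the final event additionally carries `usageMetadata`. An illustrative wire excerpt under those assumptions (token text and counts made up; candidate body elided in the final event):

data: {"candidates":[{"content":{"parts":[{"text":"Par"}],"role":"model"},"index":0}],"promptFeedback":{"safetyRatings":[]}}

data: {"candidates":[...],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":7,"totalTokenCount":12}}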
@@ -43,6 +43,7 @@
 // After highway.h
 #include "compression/compress-inl.h"
+#include "gemma/flash_attention.h"
 #include "gemma/gemma-inl.h"
 #include "ops/ops-inl.h"

 HWY_BEFORE_NAMESPACE();
@@ -57,33 +58,35 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
                              const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
                              ThreadingContext& ctx, const size_t worker) {
   GCPP_ZONE(ctx, worker, Zones::kGenAttentionQDotK);
-  if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
-    // Slightly faster: no wraparound.
+  const hn::ScalableTag<BF16> dbf;
+  const size_t qkv_dim = k.Cols();
+  HWY_ALIGN BF16 q_bf[kMaxQKVDim];
+
+  CompressPerThread tls;
+  const hn::ScalableTag<float> df;
+  CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
+                                 0);
+
+  // --seq_len must be large enough to avoid wraparound.
+  HWY_DASSERT(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()));
   for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-    const float score = Dot(q, k.Row(pos), k.Cols());
+    const float score =
+        Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos), qkv_dim);
     att[pos] = score;
   }
-  } else {
-    for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-      const size_t pos_modulo = div_seq_len.Remainder(pos);
-      const float score = Dot(q, k.Row(pos_modulo), k.Cols());
-      att[pos_modulo] = score;
-    }
-  }
 }

 void PositionalEncodingQK(float* qk, const size_t layer_idx,
-                          const LayerWeightsPtrs& layer,
-                          const AttentionActivations& activations,
+                          const AttentionActivationsPtrs& activations,
                           ThreadingContext& ctx, const size_t worker,
                           const size_t pos, const float mul) {
-  const size_t qkv_dim = layer.layer_config.qkv_dim;
-  const PostQKType& post_qk = layer.layer_config.post_qk;
+  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
+  const size_t qkv_dim = layer_config.qkv_dim;
+  const PostQKType& post_qk = layer_config.post_qk;
   // qk is either q or k, so qkv_dim is the length we operate on.
   const float* inv_timescale = activations.inv_timescale.PackedScale1();
   const bool is_global_layer = activations.config.IsGlobalLayer(layer_idx);
-  // TODO: add a config flag instead of hardcoding the model.
-  if (is_global_layer && IsVLM(activations.config.model)) {
+  if (is_global_layer && activations.config.use_global_timescale) {
    inv_timescale = activations.inv_timescale_global.PackedScale1();
  }
  // PostQKType::Rope
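For context on the `inv_timescale` tables consumed here: RoPE uses one inverse timescale per rotated pair of the head dimension, `base^(-2i/qkv_dim)`, with a larger base (1000000.0, per `CreateInvTimescale` above) for the global table. A minimal scalar sketch of table construction and rotation, assuming the split-half pair layout; this is not the vectorized `Rope` kernel:

#include <cmath>
#include <cstddef>
#include <vector>

// One inverse timescale per (x, y) pair of the head dimension.
inline std::vector<float> InvTimescales(size_t qkv_dim, double base) {
  std::vector<float> inv(qkv_dim / 2);
  for (size_t i = 0; i < inv.size(); ++i) {
    inv[i] = static_cast<float>(std::pow(base, -2.0 * i / qkv_dim));
  }
  return inv;
}

// Rotates pairs (qk[i], qk[i + half]) by angle pos * inv_timescale[i],
// then scales by mul (the query scale for Q, or 1.0 for K).
inline void RopeAndMul(float* qk, size_t qkv_dim, size_t pos,
                       const float* inv_timescale, float mul) {
  const size_t half = qkv_dim / 2;
  for (size_t i = 0; i < half; ++i) {
    const float theta = pos * inv_timescale[i];
    const float c = std::cos(theta), s = std::sin(theta);
    const float x = qk[i], y = qk[i + half];
    qk[i] = mul * (x * c - y * s);
    qk[i + half] = mul * (x * s + y * c);
  }
}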
@@ -104,62 +107,52 @@ static HWY_INLINE void WeightedSumV(
     const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
     const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, ThreadingContext& ctx,
     const size_t worker) {
-  if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
-    // Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
-    // we supported non-transposed B.
-    // TODO: 2..4x unroll
+  // --seq_len must be large enough to avoid wraparound.
+  HWY_DASSERT(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()));
+  // TODO: replace with MatMul(att, v) after it supports non-transposed B.
   MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), ctx,
                worker);
   for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
     MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
   }
-  } else {
-    {
-      const size_t pos_mod = div_seq_len.Remainder(start_pos);
-      MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), ctx,
-                   worker);
-    }
-    for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
-      const size_t pos_mod = div_seq_len.Remainder(pos);
-      MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols());
-    }
-  }
 }

 // Calculates the attention outputs for a single q, which may be updated
 // in place for RMSNorm.
 void SingleDotSoftmaxWeightedSum(
-    const size_t pos, const size_t start_pos, const size_t last_pos,
+    const size_t q_pos, const size_t kv_start_pos, const size_t kv_last_pos,
     float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
-    const size_t layer_idx, const LayerWeightsPtrs& layer,
-    const AttentionActivations& activations, float* HWY_RESTRICT att,
-    float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
+    const MatPtr& query_norm_scale, const size_t layer_idx,
+    const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
+    float* HWY_RESTRICT att_out, const SMOptions& sm_options,
+    ThreadingContext& ctx, const size_t worker) {
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
-  const size_t seq_len =
-      static_cast<size_t>(activations.div_seq_len.GetDivisor());
+  // --seq_len must be large enough to avoid wraparound.
+  HWY_DASSERT(kv_last_pos < activations.SeqLen());
+  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];

   // Apply rope and scaling to Q.
-  if (layer.query_norm_scale.HasPtr()) {
-    CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
+  if (query_norm_scale.HasPtr()) {
+    CallUpcasted(&query_norm_scale, [&](const auto* weights_t) {
       RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q,
-                     layer.layer_config.qkv_dim, ctx, worker);
+                     layer_config.qkv_dim, ctx, worker);
     });
   }

-  PositionalEncodingQK(q, layer_idx, layer, activations, ctx, worker, pos,
+  PositionalEncodingQK(q, layer_idx, activations, ctx, worker, q_pos,
                        query_scale);

-  QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker);
+  QDotK(kv_start_pos, kv_last_pos, activations.div_seq_len, q, k, att, ctx,
+        worker);

   // SoftMax with optional SoftCap yields "probabilities" in att.
-  const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
-  const Logits logits(att, att_len);
+  const Logits logits(att, kv_last_pos + 1);
   MaybeLogitsSoftCap(att_cap, logits, ctx, worker);
-  Softmax(logits, ctx, worker, /*temperature=*/1.0f);
+  Softmax(logits, ctx, worker, /*temperature=*/1.0f, sm_options);

-  WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
-               ctx, worker);
+  WeightedSumV(kv_start_pos, kv_last_pos, activations.div_seq_len, att, v,
+               att_out, ctx, worker);
 }

 // The attention window usually starts at 0 unless `pos` is larger than
@@ -170,13 +163,13 @@ size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) {
 }

 void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
-                           const LayerWeightsPtrs& layer,
-                           AttentionActivations& activations, QBatch& qbatch,
-                           ThreadingContext& ctx) {
+                           const MatPtr& query_norm_scale,
+                           AttentionActivationsPtrs& activations,
+                           QBatch& qbatch, ThreadingContext& ctx) {
   GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive);

   const hwy::Divisor div_qbatch(qbatch.Size());
-  const LayerConfig& layer_config = layer.layer_config;
+  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
   const size_t qkv_dim = layer_config.qkv_dim;

   // A "head group" in the context of GQA refers to a collection of query

@@ -184,8 +177,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
   const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;

   const size_t cache_layer_size = layer_config.CacheLayerSize();
-  const size_t seq_len =
-      static_cast<size_t>(activations.div_seq_len.GetDivisor());
+  const size_t seq_len = activations.SeqLen();
   // All layers should have the same number of heads.
   HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads);
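To make the head-group arithmetic concrete: with `kHeadGroups = heads / kv_heads`, each group of `kHeadGroups` consecutive query heads shares one KV head. A small standalone sketch of the mapping (hypothetical helper, not a gemma.cpp function):

#include <cstddef>

// GQA: query head -> KV head. With heads = 8 and kv_heads = 2,
// head_groups = 4, so query heads 0..3 use KV head 0 and 4..7 use KV head 1.
inline size_t KvHeadForQueryHead(size_t head, size_t heads, size_t kv_heads) {
  const size_t head_groups = heads / kv_heads;
  return head / head_groups;
}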
@@ -196,12 +188,12 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
     GCPP_ZONE(ctx, worker, Zones::kGenAttentionDotSoftmaxWeightedSumPar);

     const size_t qi = div_qbatch.Remainder(tq_idx);
-    const size_t batch_idx = div_qbatch.Divide(tq_idx);
+    const size_t token_idx = div_qbatch.Divide(tq_idx);
     auto& kv_cache = qbatch.KV(qi).kv_cache;

     // Find the token position in the query and calculate
     // the range of cache positions to attend to.
-    const size_t pos = qbatch.Pos(qi) + batch_idx;
+    const size_t pos = qbatch.Pos(qi) + token_idx;
     const size_t start_pos = StartPos(pos, activations.config, layer_idx);
     size_t last_pos = pos;
     const size_t prefix_end = qbatch.PrefixEnd(qi);
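`tq_idx` interleaves queries and tokens, and `div_qbatch` splits it back into a (query index, token index) pair. A plain-integer equivalent of the `hwy::Divisor` calls above, for illustration:

#include <cstddef>

// Equivalent of div_qbatch.Remainder/Divide: tq_idx enumerates
// (token_idx, qi) pairs with qi varying fastest.
inline void SplitInterleavedIndex(size_t tq_idx, size_t qbatch_size,
                                  size_t& qi, size_t& token_idx) {
  qi = tq_idx % qbatch_size;         // index into qbatch, in [0, qbatch_size)
  token_idx = tq_idx / qbatch_size;  // index along the token sequence
}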
@@ -214,6 +206,8 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
       float* HWY_RESTRICT att = activations.att.Row(tq_idx) + head * seq_len;
       float* HWY_RESTRICT att_out =
           activations.att_out.Row(tq_idx) + head * qkv_dim;
+      SMOptions sm_options{.max_out = activations.softmax_max.Row(tq_idx) + head,
+                           .d_out = activations.softmax_d.Row(tq_idx) + head};

       // Make strided read-only views into the kv cache for
       // this query and head.

@@ -224,8 +218,10 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
       MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
       v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());

-      SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
-                                  layer, activations, att, att_out, ctx, worker);
+      constexpr size_t offset = 0;  // placeholder, do not remove
+      SingleDotSoftmaxWeightedSum(pos + offset, start_pos, last_pos, q, k, v,
+                                  query_norm_scale, layer_idx, activations, att,
+                                  att_out, sm_options, ctx, worker);
     };

     {
@@ -246,7 +242,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
 // Fills activations.q and writes to KV cache.
 static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
                                   const LayerWeightsPtrs& layer,
-                                  AttentionActivations& activations,
+                                  AttentionActivationsPtrs& activations,
                                   const QBatch& qbatch, const int flags,
                                   MatMulEnv& env) {
   GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(),
@@ -271,10 +267,14 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
                        layer.qkv_einsum_w2.Rows()));
   for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
        ++interleaved_idx) {
+    // Index into qbatch, within [0, qbatch.Size())
     const size_t qi = div_qbatch.Remainder(interleaved_idx);
-    const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
-    const size_t cache_pos =
-        activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx);
+    // Index along token sequence, within [0, num_tokens)
+    const size_t token_idx = div_qbatch.Divide(interleaved_idx);
+    const size_t cache_pos = qbatch.Pos(qi) + token_idx;
+    // --seq_len must be large enough to avoid wraparound.
+    HWY_DASSERT(cache_pos < activations.SeqLen());

     env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
         qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
   }
@@ -286,15 +286,16 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
   // Note that 2D parallelism is not worth the fork/join overhead because the
   // tasks are very lightweight.
   ParallelFor(
-      ParallelismStrategy::kFlat, kv_heads * num_interleaved, env.ctx,
+      Parallelism::kFlat, kv_heads * num_interleaved, env.ctx,
       /*cluster_idx=*/0, Callers::kAttComputeQKV,
       [&](size_t task, size_t worker) HWY_ATTR {
         const size_t head = task % kv_heads;
         const size_t interleaved_idx = task / kv_heads;
         const size_t qi = div_qbatch.Remainder(interleaved_idx);
-        const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
-        const size_t pos = qbatch.Pos(qi) + batch_idx;
-        const size_t cache_pos = activations.div_seq_len.Remainder(pos);
+        const size_t token_idx = div_qbatch.Divide(interleaved_idx);
+        const size_t cache_pos = qbatch.Pos(qi) + token_idx;
+        // --seq_len must be large enough to avoid wraparound.
+        HWY_DASSERT(cache_pos < activations.SeqLen());
         auto& kv_cache = qbatch.KV(qi).kv_cache;
         KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
                                 layer_idx * cache_layer_size +
@@ -313,35 +314,18 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
         });
       }

-      PositionalEncodingQK(kv_f32, layer_idx, layer, activations, env.ctx,
-                           worker, pos, /*mul=*/1.0f);
+      constexpr size_t offset = 0;  // placeholder, do not remove
+      PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker,
+                           cache_pos + offset,
+                           /*mul=*/1.0f);
       CompressPerThread tls;
       Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
     });
 }

 // Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
 // head_dim (`qkv_dim`) into output (`layer_out`).
 static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
                                 AttentionActivations& activations,
                                 MatMulEnv& env) {
   GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
-  const LayerConfig& layer_config = layer.layer_config;
-  (void)layer_config;  // For HWY_DASSERT
   // att_weights and att_out are concatenated heads, each of length
   // layer_config.qkv_dim. Thus the [num_interleaved,
   // layer_config.model_dim] matmul output is the sum over heads. Compare
   // gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD',
   // encoded)
-  HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 &&
-              layer_config.qkv_dim != 0);
   CallMatMul(activations.att_out, layer.att_weights, /*add=*/nullptr, env,
              activations.att_sums);
 }
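The comment in SumHeads relies on a reshaping identity: summing per-head projections equals one matmul over the concatenated head outputs. A scalar sketch of that equivalence (illustrative shapes and names, not the `CallMatMul` API):

#include <cstddef>

// att_out row: concatenated heads, length heads * qkv_dim.
// w: row-major [heads * qkv_dim, model_dim] stacked per-head projections.
// out[d] = sum_h sum_k att_out[h*qkv_dim + k] * w[h*qkv_dim + k][d], i.e. a
// single [1, heads*qkv_dim] x [heads*qkv_dim, model_dim] matmul already sums
// over heads.
inline void SumHeadsScalar(const float* att_out, const float* w, size_t heads,
                           size_t qkv_dim, size_t model_dim, float* out) {
  for (size_t d = 0; d < model_dim; ++d) out[d] = 0.0f;
  for (size_t hk = 0; hk < heads * qkv_dim; ++hk) {
    for (size_t d = 0; d < model_dim; ++d) {
      out[d] += att_out[hk] * w[hk * model_dim + d];
    }
  }
}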
 void GemmaAttention(size_t num_tokens, const size_t layer_idx,
                     const LayerWeightsPtrs& layer,
-                    AttentionActivations& activations, QBatch& qbatch,
+                    AttentionActivationsPtrs& activations, QBatch& qbatch,
                     MatMulEnv& env, int flags) {
   GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttention);

@@ -353,13 +337,14 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,

   ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
   if (flags & kAttentionUseOld) {
-    DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
-                          env.ctx);
+    DotSoftmaxWeightedSum(num_tokens, layer_idx, layer.query_norm_scale,
+                          activations, qbatch, env.ctx);
   } else {
     // * 2 does not help on Turin.
     FlashAttention(num_tokens,
                    /*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1,
-                   layer_idx, layer, activations, qbatch, env.ctx);
+                   layer_idx, layer.query_norm_scale, activations, qbatch,
+                   env.ctx);
   }
   SumHeads(layer, activations, env);
 }
@@ -29,8 +29,7 @@ namespace gcpp {
 #define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE)                               \
   namespace NAMESPACE {                                                       \
   void PositionalEncodingQK(float* qk, size_t layer_idx,                      \
-                            const LayerWeightsPtrs& layer,                    \
-                            const AttentionActivations& activations,          \
+                            const AttentionActivationsPtrs& activations,      \
                             ThreadingContext& ctx, size_t worker, size_t pos, \
                             float mul);                                       \
                                                                               \

@@ -39,18 +38,18 @@ namespace gcpp {
   void SingleDotSoftmaxWeightedSum(                                          \
       const size_t pos, const size_t start_pos, const size_t last_pos,       \
       float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
-      size_t layer_idx, const LayerWeightsPtrs& layer,                       \
-      const AttentionActivations& activations, float* HWY_RESTRICT att,      \
+      const MatPtr& query_norm_scale, size_t layer_idx,                      \
+      const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,  \
       float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker);    \
                                                                              \
   void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx,      \
-                             const LayerWeightsPtrs& layer,                  \
-                             AttentionActivations& activations,              \
+                             const MatPtr& query_norm_scale,                 \
+                             AttentionActivationsPtrs& activations,          \
                              QBatch& qbatch, ThreadingContext& ctx);         \
                                                                              \
   void GemmaAttention(size_t num_tokens, const size_t layer_idx,             \
                       const LayerWeightsPtrs& layer,                         \
-                      AttentionActivations& activations, QBatch& qbatch,     \
+                      AttentionActivationsPtrs& activations, QBatch& qbatch, \
                       MatMulEnv& env, int flags);                            \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */                \
   }  // namespace NAMESPACE
@@ -0,0 +1,570 @@
#include <cstddef>
#include <cstring>  // strcmp
#include <memory>
#include <numeric>
#include <optional>
#include <vector>

#include "gtest/gtest.h"
#include "compression/types.h"  // GEMMA_DISABLED_TARGETS
#include "gemma/activations.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/kv_cache.h"
#include "gemma/weights.h"
#include "ops/matmul.h"
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"

#ifndef HWY_DISABLED_TARGETS
// These tests aren't designed to suss out instruction set specific problems.
// Disable most targets to keep the tests fast and simple and not have to
// worry about tolerances on floating point results.
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif  // HWY_DISABLED_TARGETS

// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/attention_test.cc"  // NOLINT
// clang-format on
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "gemma/attention.h"
#include "gemma/configs.h"
#include "util/test_util.h"
#include "hwy/tests/test_util-inl.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {

void FillRandom(MatPtrT<float>& mat, uint64_t seed) {
  hwy::RandomState rng(seed);
  for (size_t r = 0; r < mat.Rows(); ++r) {
    float* row = mat.Row(r);
    for (size_t c = 0; c < mat.Cols(); ++c) {
      row[c] = static_cast<float>(RandomGaussian(rng));
    }
  }
}

void AllocateAndFillRandom(MatPtr& mat, const Allocator& allocator,
                           std::vector<MatOwner>& mat_owners, uint64_t seed) {
  if (mat.IsEmpty()) return;
  if (mat.GetType() == Type::kUnknown) {
    mat.SetType(Type::kF32);
  }
  mat_owners.emplace_back();
  mat_owners.back().AllocateFor(mat, allocator, MatPadding::kPacked);
  MatPtrT<float> mat_f32(mat);
  FillRandom(mat_f32, seed);
}

struct TestState {
  TestState() : ctx({}), env(ctx) {}
  ThreadingContext ctx;
  std::vector<MatOwner> mat_owners;
  MatMulEnv env;
};

struct TestModelState {
  TestModelState(TestState& state)
      : config(Model::GEMMA2_2B, Type::kF32, PromptWrapping::GEMMA_PT),
        tensor_info_registry(config),
        layer_config(config.layer_configs[0]),
        layer(0, layer_config, tensor_info_registry) {
    config.att_cap = 1024.0f;
    AllocateAndFillRandom(layer.qkv_einsum_w, state.ctx.allocator,
                          state.mat_owners, 42);
    AllocateAndFillRandom(layer.attn_vec_einsum_w, state.ctx.allocator,
                          state.mat_owners, 43);
    AllocateAndFillRandom(layer.gating_einsum_w, state.ctx.allocator,
                          state.mat_owners, 44);
    AllocateAndFillRandom(layer.linear_w, state.ctx.allocator, state.mat_owners,
                          45);
    layer.Fixup(state.mat_owners, state.ctx);
  }

  ModelConfig config;
  TensorInfoRegistry tensor_info_registry;
  const LayerConfig& layer_config;
  LayerWeightsPtrs layer;
};

struct TestAttentionState {
  TestAttentionState(TestState& state, TestModelState& model_state,
                     size_t num_tokens, size_t qbatch_size,
                     AttentionImpl attention_impl)
      : num_tokens(num_tokens),
        qbatch_size(qbatch_size),
        batch_size(qbatch_size * num_tokens),
        runtime_config{.attention_impl = attention_impl},
        tokens(num_tokens),
        attention_storage_(model_state.config, model_state.layer_config,
                           batch_size, num_tokens, runtime_config,
                           state.ctx.allocator, row_ptrs_),
        attention(model_state.config, num_tokens, attention_storage_) {
    for (size_t i = 0; i < qbatch_size; ++i) {
      kv_caches.emplace_back(model_state.config, inference_args,
                             state.ctx.allocator);
    }
    activations.emplace(
        runtime_config, model_state.config, runtime_config.prefill_tbatch_size,
        kv_caches[0].SeqLen(), state.env.ctx, state.env.row_ptrs);
    // Tokens don't matter, since we fill in pre_att_rms_out before calling
    // GemmaAttention.
    std::iota(tokens.begin(), tokens.end(), 1);
    for (size_t i = 0; i < qbatch_size; ++i) {
      prompts.emplace_back(tokens);
    }
    all_queries.emplace(prompts,
                        hwy::Span<KVCache>(kv_caches.data(), kv_caches.size()));
    qbatch.emplace(/*start=*/0, /*max_size=*/qbatch_size, *all_queries);
    FillRandom(attention.pre_att_rms_out, 46);
  }

  const size_t num_tokens;
  const size_t qbatch_size;
  const size_t batch_size;
  InferenceArgs inference_args;
  RuntimeConfig runtime_config;
  std::vector<KVCache> kv_caches;
  std::optional<Activations> activations;
  std::vector<int> tokens;
  std::vector<PromptTokens> prompts;
  std::optional<AllQueries> all_queries;
  std::optional<QBatch> qbatch;
|
||||
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs_;
|
||||
AttentionActivations attention_storage_;
|
||||
AttentionActivationsPtrs attention;
|
||||
};
|
||||
|
||||
double GetTolerance() {
|
||||
const char* target_name = hwy::TargetName(HWY_TARGET);
|
||||
if (strncmp(target_name, "AVX2", 4) == 0) {
|
||||
return 2e-2;
|
||||
} else if (strncmp(target_name, "AVX3", 4) == 0) {
|
||||
return 3e-4;
|
||||
} else if (strncmp(target_name, "NEON", 4) == 0) {
|
||||
return 5e-3;
|
||||
} else {
|
||||
return 1e-7;
|
||||
}
|
||||
}
|
||||
|
||||
template <size_t kNumTokens, size_t kQBatchSize, size_t kDims>
|
||||
void CompareAttSumsWithGolden(
|
||||
const AttentionActivationsPtrs& attention,
|
||||
const float (&golden)[kNumTokens][kQBatchSize][kDims]) {
|
||||
ASSERT_EQ(attention.att_sums.Rows(), kNumTokens * kQBatchSize);
|
||||
ASSERT_LE(kDims, attention.att_sums.Cols());
|
||||
|
||||
hwy::AlignedFreeUniquePtr<float[]> actual_row =
|
||||
hwy::AllocateAligned<float>(kDims);
|
||||
for (size_t token_idx = 0; token_idx < kNumTokens; ++token_idx) {
|
||||
for (size_t qi = 0; qi < kQBatchSize; ++qi) {
|
||||
const size_t i = token_idx * kQBatchSize + qi;
|
||||
for (size_t j = 0; j < kDims; ++j) {
|
||||
actual_row[j] = hwy::F32FromBF16(attention.att_sums.Row(i)[j]);
|
||||
}
|
||||
EXPECT_TRUE(hwy::CompareArraySimilar(
|
||||
golden[token_idx][qi], actual_row.get(), kDims, GetTolerance(),
|
||||
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
|
||||
<< "att_sums mismatch for token_idx=" << token_idx << " qi=" << qi;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <size_t kNumTokens, size_t kQBatchSize, size_t kDims>
|
||||
void CompareKVCacheWithGolden(
|
||||
const ModelConfig& config, hwy::Span<KVCache> kv_caches, const size_t layer,
|
||||
const size_t kv_head,
|
||||
const float (&k_golden)[kNumTokens][kQBatchSize][kDims],
|
||||
const float (&v_golden)[kNumTokens][kQBatchSize][kDims]) {
|
||||
const size_t qbatch_size = kv_caches.size();
|
||||
ASSERT_EQ(kQBatchSize, qbatch_size);
|
||||
const size_t start_offset = 0;
|
||||
const size_t qkv_dim = config.layer_configs[0].qkv_dim;
|
||||
|
||||
hwy::AlignedFreeUniquePtr<float[]> actual_k_row =
|
||||
hwy::AllocateAligned<float>(kDims);
|
||||
hwy::AlignedFreeUniquePtr<float[]> actual_v_row =
|
||||
hwy::AllocateAligned<float>(kDims);
|
||||
|
||||
const size_t cache_layer_size = config.layer_configs[layer].CacheLayerSize();
|
||||
const size_t head_offset = kv_head * qkv_dim * 2;
|
||||
const size_t kv_offset = layer * cache_layer_size + head_offset;
|
||||
|
||||
for (size_t token_idx = 0; token_idx < kNumTokens; ++token_idx) {
|
||||
for (size_t qi = 0; qi < kQBatchSize; ++qi) {
|
||||
const float* cache_row =
|
||||
kv_caches[qi].kv_cache.Row(start_offset + token_idx);
|
||||
for (size_t j = 0; j < kDims; ++j) {
|
||||
actual_k_row[j] = cache_row[kv_offset + j];
|
||||
actual_v_row[j] = cache_row[kv_offset + qkv_dim + j];
|
||||
}
|
||||
EXPECT_TRUE(hwy::CompareArraySimilar(
|
||||
k_golden[token_idx][qi], actual_k_row.get(), kDims, GetTolerance(),
|
||||
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
|
||||
<< "K cache mismatch for token_idx=" << token_idx << " qi=" << qi
|
||||
<< " kv_head=" << kv_head;
|
||||
EXPECT_TRUE(hwy::CompareArraySimilar(
|
||||
v_golden[token_idx][qi], actual_v_row.get(), kDims, GetTolerance(),
|
||||
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
|
||||
<< "V cache mismatch for token_idx=" << token_idx << " qi=" << qi
|
||||
<< " kv_head=" << kv_head;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <size_t kNumTokens, size_t kQBatchSize, size_t kDims>
|
||||
void CompareQVecsWithGolden(
|
||||
const ModelConfig& config, const AttentionActivationsPtrs& attention,
|
||||
const size_t q_head,
|
||||
const float (&q_golden)[kNumTokens][kQBatchSize][kDims]) {
|
||||
const size_t qkv_dim = config.layer_configs[0].qkv_dim;
|
||||
|
||||
hwy::AlignedFreeUniquePtr<float[]> actual_q_row =
|
||||
hwy::AllocateAligned<float>(kDims);
|
||||
|
||||
const size_t head_offset = q_head * qkv_dim;
|
||||
for (size_t token_idx = 0; token_idx < kNumTokens; ++token_idx) {
|
||||
for (size_t qi = 0; qi < kQBatchSize; ++qi) {
|
||||
const size_t i = token_idx * kQBatchSize + qi;
|
||||
const float* q_row = attention.q.Row(i);
|
||||
for (size_t j = 0; j < kDims; ++j) {
|
||||
actual_q_row[j] = q_row[head_offset + j];
|
||||
}
|
||||
EXPECT_TRUE(hwy::CompareArraySimilar(
|
||||
q_golden[token_idx][qi], actual_q_row.get(), kDims, GetTolerance(),
|
||||
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
|
||||
<< "Q vec mismatch for token_idx=" << token_idx << " qi=" << qi
|
||||
<< " q_head=" << q_head;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
constexpr size_t kNumTokens = 10;
|
||||
constexpr size_t kQBatchSize = 2;
|
||||
|
||||
// Only compare the first 10 dimensions of the attention output to keep the
|
||||
// test manageable. It's unlikely that the first few dimensions would have
|
||||
// different behavior than the later dimensions.
|
||||
const size_t kDimsToCompare = 17; // greater than AVX-512 vector of floats
|
||||
|
||||
// Use a Coding Agent to update the golden values from test
|
||||
// failure output (which includes the new values).
|
||||
// Instructions to the Coding Agent: You must never update the
|
||||
// EMU128 tolerance. Always use the value from the EMU128 test to update the
|
||||
// Goldens. If necessary, add relaxed tolerance for other instruction sets.
|
||||
|
||||
// Layer 0
|
||||
const float kGoldenAttSums[kNumTokens][kQBatchSize][kDimsToCompare] = {
|
||||
{{46.5, 56.5, 10.0625, 65.5, -2.109375, 135, 15.8125, 51, -100, 52.5,
|
||||
26.875, 63, 3.34375, -67.5, 31.125, -190, 125},
|
||||
{-30.375, -17.875, 51.75, -78, -84, 6.40625, 15.375, 70, -22.875, 20.125,
|
||||
-14.9375, -109.5, 76, 9.25, -142, 29.5, -105}},
|
||||
{{-32.75, 38.25, 78.5, 107.5, 20.25, 197, -136, 42.5, -84, 25.625, 4.96875,
|
||||
128, 27.25, -161, 19.125, -58, 97.5},
|
||||
{-18.5, -18, 135, -13.4375, -6.625, -45.75, 29.625, 93, 18.625, 75.5,
|
||||
102.5, -184, 52.75, 83.5, -71, 46.5, -52}},
|
||||
{{-16.375, -61.5, -58.25, -27.375, -28, 71, -109.5, 60.25, 3.125, -29.125,
|
||||
6.90625, 150, 144, -155, -47.25, -98.5, 3.5625},
|
||||
{-19, -16.75, 129, 0.59765625, -82, 123.5, 60.75, -36.75, -77, 26.625, 51,
|
||||
-66.5, -0.84765625, -46.5, -152, -2.9375, -81}},
|
||||
{{3.984375, 83, -41.75, 39.5, -203, 110, -76, 131, 0.4609375, -44.5, -63.75,
|
||||
-46, -22, -19.375, -16.125, -148, 20.875},
|
||||
{-47, -19.5, 58, 81.5, 21.75, -30, -118, 44.25, -149, 22.5, 188, -66.5, 33,
|
||||
10.9375, -52.5, 23.25, 75}},
|
||||
{{64, -31, -89, -92.5, -11.1875, -54.75, -302, 3.453125, -108, 39.25,
|
||||
-34.75, 18, -52, 100, -186, -75.5, 50.75},
|
||||
{7.6875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.4375, 82.5,
|
||||
39.25, 65, 47.25, -89.5, -34.25, 137}},
|
||||
{{39.75, 17.875, 115, 38.75, -44, 139, -53.25, -23.875, -13.0625, 38.5,
|
||||
32.5, 53.75, 109, 4.09375, 57.5, -20.5, 132},
|
||||
{143, 249, 5.09375, 0.83984375, 27.875, -5.84375, 30.25, -101.5, 65.5,
|
||||
13.5, 195, -10.0625, 97.5, 2.203125, -97.5, -100, -19.25}},
|
||||
{{-30.125, -169, -150, 58, -35.75, 22.75, 36.5, -32.25, -8.9375, 55.25,
|
||||
-117, 26.375, 39.5, 125, 66, 48.75, 20.75},
|
||||
{137, 5.25, 61.25, 37, -42.75, 240, 62, -164, 11.3125, 173, 174, 23.5,
|
||||
88.5, 48.5, -46.25, -36.75, 101.5}},
|
||||
{{-103, -47.5, 39, -48, -67.5, 121, -136, 99, 80, -47.5, 107.5, 48.75, 97.5,
|
||||
125, -53.5, -14.625, 262},
|
||||
{29.875, 7.34375, -36.75, -14.5, -27.5, 44.75, -67.5, -40.75, 71.5, 172,
|
||||
81, -27.25, -3.03125, 111, -167, 59, 176}},
|
||||
{{-37.25, 109.5, -26.125, -115.5, 108, 57.25, 1.3671875, 72, -122.5, 59.25,
|
||||
-52, -12.625, 43.25, 16.25, -41.75, 26.5, 70.5},
|
||||
{40.25, 53.25, -142, 78.5, 38, 4.3125, -27.75, -134, -85, 107.5, 2.5, 93.5,
|
||||
58.25, 173, -53.5, 25.125, 4.8125}},
|
||||
{{-8.4375, -35, -35.5, 131, -33.25, 106, 109.5, -92, -135, 80, 21.5,
|
||||
-17.125, 15.25, 143, -27, 103, 101},
|
||||
{-77, 40.75, -10.125, 33.25, -33, 104, -7.6875, 85.5, -40, 93, 61, 14.5625,
|
||||
8.125, -99.5, 13.6875, -11.6875, 33}},
|
||||
};
|
||||
|
||||
// Layer 0, *K*V Head 0
|
||||
const float kGoldenK[kNumTokens][kQBatchSize][kDimsToCompare] = {
|
||||
{{-4.51717567, 6.93118095, 6.48003578, 9.12825584, 2.38755274, 11.8121576,
|
||||
1.65376127, 5.04456615, -7.19549274, 2.57609844, 3.55331731, -3.48494458,
|
||||
-8.90498638, 9.66047478, -0.379868984, 6.37043715, -2.24351144},
|
||||
{0.152208567, 3.14520073, -8.35154343, 5.44226503, -6.74000502,
|
||||
-1.43484437, -4.72092056, -9.48932, -6.12409401, -1.55352509, -3.90701318,
|
||||
2.12124252, 3.93649936, -8.09877586, -3.30277514, -0.898857355,
|
||||
1.76684189}},
|
||||
{{4.378829, 5.05565643, -7.63948059, -5.74608946, 2.90109587, 0.155819178,
|
||||
4.56115055, 1.37885749, 1.48427355, -1.07145202, 2.82399392, -1.20864201,
|
||||
3.05434561, -2.65185618, -0.0731391, -8.2279253, 7.63228416},
|
||||
{-0.702698231, 1.49563932, 6.42149782, -6.68306589, 1.85317755,
|
||||
-7.70267582, 2.07357907, -7.60303402, -0.514724255, 0.308567047,
|
||||
5.99250412, -4.67359257, -3.49322176, -2.62086344, -3.18411255,
|
||||
2.04027057, -4.29057407}},
|
||||
{{-1.20844436, 4.14724302, 6.04515219, 8.7753458, -0.975198627, 0.564640105,
|
||||
5.39941597, 4.64036179, 0.366614938, 3.48258138, -0.470701456, 15.2267399,
|
||||
4.63302803, 9.12662697, -5.89148045, 2.25731587, 5.24449492},
|
||||
{4.57078934, -4.60315752, -3.3364439, 1.29875994, -3.40833569, -6.95262,
|
||||
-6.39040232, -6.60212612, 6.63269806, -0.815209687, -5.0346446,
|
||||
-4.13564968, 8.25674057, -6.0910182, -8.21130085, -8.91020393,
|
||||
10.6188011}},
|
||||
{{0.602011144, 2.22505236, 3.62411499, -4.07026958, 12.8036356, 3.76139069,
|
||||
6.99502087, 7.02500725, -2.51568675, 4.2489934, 0.00210827589,
|
||||
-1.43267739, -2.10394144, -0.0506809056, -1.54883039, 4.3740139,
|
||||
-1.61869526},
|
||||
{-6.37204599, -3.34989691, 2.10935307, 4.23634195, 5.79134035, 13.502944,
|
||||
-2.19158888, -1.55771351, -1.22244942, 3.36499929, -2.11375904,
|
||||
-4.5448761, 1.0611912, -2.47849369, -0.212709218, 0.363292456,
|
||||
7.91467094}},
|
||||
{{-8.85739231, -4.08585882, -0.618261, 6.52911091, 5.14922285, 7.6869874,
|
||||
0.750387549, -0.812200725, 2.7509625, 6.29693508, -1.77248931, 5.68896484,
|
||||
-6.9369607, -4.61359406, 0.184977874, -1.27769828, -2.1619854},
|
||||
{-8.2555, 2.84032059, -1.03791106, 2.07648611, -4.94546843, 1.76888537,
|
||||
-1.75901175, 11.2628574, 1.41086221, -3.58669901, -2.85925198, 2.29133463,
|
||||
1.55509436, -0.0553357825, -10.0363655, 1.94261, -2.95691729}},
|
||||
{{0.919141412, 1.97533965, -11.3202848, -3.3137629, -4.7161727, 5.07012081,
|
||||
1.76256621, 8.20588207, 6.05700159, -3.89765406, -1.13639557, -1.32326794,
|
||||
-3.01544905, -0.585309267, 2.60637712, 2.83708405, -3.39202118},
|
||||
{9.11918, 2.11261511, -5.87290621, 11.6033278, -4.66597795, -7.13774204,
|
||||
-9.10563755, -2.48294282, 3.35282946, -3.75122213, 0.404774547,
|
||||
-9.11625195, 4.85711479, 1.43184578, 1.47673059, -4.75093, -3.45323014}},
|
||||
{{4.17705393, -4.95192289, -10.5068378, 3.90004015, -3.51306129, 5.38068056,
|
||||
0.901511431, 11.222868, 2.67285442, 9.18779, 5.61346769, 3.06534624,
|
||||
-3.78898215, 0.767340839, 15.8207836, -4.14079094, -4.63177109},
|
||||
{3.61795235, -7.00262165, 2.08284521, -6.70515728, 1.93205631, 2.84467721,
|
||||
3.94591737, -6.18882942, -1.78465152, -9.39100933, -10.8780289,
|
||||
6.32468653, 6.53142738, -3.30765963, 2.89132166, 4.53347206, 1.89792418}},
|
||||
{{-0.361971855, -1.57735932, 5.07296801, -1.55669761, -1.44996238,
|
||||
7.29838896, 5.23075104, -0.512441278, -3.59834242, 2.38584423, 6.48518324,
|
||||
-1.48220074, -2.4264791, 10.7237988, 5.64735842, 5.6251297, -7.04244423},
|
||||
{-0.795628309, 7.30230665, -1.71035647, -16.6999454, 3.05102086,
|
||||
-4.9243927, 4.28508186, -0.694577456, 6.58464718, 4.40330124, 3.3250041,
|
||||
1.90579033, -6.29048729, 2.55308104, -4.9746747, -0.681708, -5.98152351}},
|
||||
{{2.57555652, -3.5651083, 0.784440041, -4.7043705, 2.37520599, -3.62385964,
|
||||
-3.48913693, -7.28049421, -5.48726082, 1.95519221, 7.25192928, 3.07074118,
|
||||
-11.9897156, 5.92244673, 5.07564354, 0.162699938, -6.00809956},
|
||||
{5.56260443, -5.7683115, 1.26402235, -17.507719, 4.18873024, -3.20694613,
|
||||
-4.42512083, 1.78077614, -3.25167561, 0.864362717, 0.474019766,
|
||||
-7.92327404, -2.27795148, -0.436354101, -3.15722394, 0.415780187,
|
||||
2.60931611}},
|
||||
{{-9.43858051, 0.391518891, -2.74012518, 4.9842453, 7.48263216, -16.3434925,
|
||||
-4.75156116, -1.99114823, 3.99918842, -5.95400572, 10.8700314, 1.07596064,
|
||||
0.30389142, 8.39548779, -5.11913681, 5.45641088, -5.63240337},
|
||||
{-1.22347319, 9.57339382, -1.31736016, -5.02770805, -4.81617355,
|
||||
-1.96618557, -0.456317186, 12.6451035, -1.50221801, 6.7991147,
|
||||
-5.97842169, 1.85410941, -8.44729, 0.378282309, 0.0442156792, 17.6773052,
|
||||
-7.43491}},
|
||||
};
|
||||
|
||||
// Layer 0, K*V* Head 0
|
||||
const float kGoldenV[kNumTokens][kQBatchSize][kDimsToCompare] = {
|
||||
{{2.77553034, -7.67514181, -1.60433948, 4.67795134, -1.75084186, 8.57896423,
|
||||
-1.15065813, -3.75088787, -4.7442131, -1.68890858, -10.0202332,
|
||||
-4.20167446, 9.36844635, 13.7364845, 11.5634, 2.95288706, 2.89380026},
|
||||
{-4.79950905, -1.66658688, 4.14471292, -4.95649052, -5.4200325, 3.52626801,
|
||||
-10.9432049, 0.338347554, -1.53204226, 0.473476171, -0.58271, 1.42195463,
|
||||
0.301399827, -4.40214968, -2.12298298, 9.27825642, -0.690600872}},
|
||||
{{-10.6566734, 4.12785721, 4.54053593, -1.39667869, -1.55028772, 0.20508635,
|
||||
-0.00620913506, 2.93214, -0.788117647, 2.78032446, -2.68898249, 9.5985508,
|
||||
-10.6630878, -11.9006901, 0.851743698, 0.581826329, 5.21927929},
|
||||
{-0.322291255, 2.63848567, -2.30808377, -13.0153809, 2.74378228,
|
||||
3.21460533, 0.688529968, 2.37544608, 6.06825066, 4.57566404, 1.17124248,
|
||||
-7.96587658, -2.65279341, 4.75271225, -4.09937954, -10.3570251,
|
||||
3.30500841}},
|
||||
{{-3.34342527, 6.03099537, 6.335958, 0.993818045, 0.905343294, 6.93058586,
|
||||
3.9635396, 10.8044815, 7.8620863, -10.1157322, -3.92666101, -0.183003783,
|
||||
-5.27309418, -1.45110512, -8.96734, -2.63866425, 2.19913912},
|
||||
{16.416317, -1.62025332, 2.3161006, 3.32571959, -1.79581594, -10.2925539,
|
||||
-5.86338425, -6.36642933, 9.18872166, 5.95524168, 6.38640785, 8.23832,
|
||||
-6.57342291, -14.2017632, 1.10925388, 4.27255058, -2.65661311}},
|
||||
{{6.58254147, -6.96165133, -4.97437, -2.33467388, 5.83671236, -0.794236898,
|
||||
-2.03117108, -3.93387103, -5.96872902, 5.83316422, 3.01795, -4.05260706,
|
||||
-4.39556885, 3.24399853, 10.1573639, 4.71967888, 0.274738848},
|
||||
{7.13243389, -8.04649162, 2.53055143, 2.0771277, -0.667295456, -13.0285645,
|
||||
0.960428238, -2.11983275, 8.18105602, -6.72609901, -5.46944714,
|
||||
0.204244614, 0.0900330544, 8.86620903, 4.63697529, 3.19756651,
|
||||
2.99392676}},
|
||||
{{9.52539158, -4.3840766, -6.94514465, -2.75913763, -10.8364506,
|
||||
-3.95606327, 2.43603897, -5.78482246, -0.801304817, 8.23436832,
|
||||
-7.11484337, 2.53943753, -0.652261257, 9.77392, 3.53345847, -9.62052822,
|
||||
16.0471916},
|
||||
{6.89768124, 2.36394405, -2.08569574, -0.682706833, 3.38872, -6.28313875,
|
||||
4.79594612, 4.93417454, -6.40791416, -10.7355442, -5.66094208, 2.44881392,
|
||||
1.99794042, -9.19855404, -4.02383137, -3.63013959, -5.65853405}},
|
||||
{{1.64614546, -3.93421197, -0.48935914, 5.48284435, -7.69781828, 11.8203125,
|
||||
1.81672478, -1.42535269, -5.26496315, -5.31612349, -4.19499826,
|
||||
7.06049395, 0.18029356, -0.0519902706, 10.317358, 2.19345617, 3.5296216},
|
||||
{7.52353811, 3.56836724, 0.414305687, 0.340799928, 2.44263697, 7.52111912,
|
||||
0.246491909, -11.1172791, -3.82061529, 3.24794388, 0.751524329,
|
||||
3.14019632, 6.33881855, -0.169233799, 7.82640171, 1.5389179, 8.15851307}},
|
||||
{{-2.48950672, -8.55112743, 8.04663277, -5.77116871, -0.637019753,
|
||||
-7.65882111, -7.49037457, 3.8041625, -3.57038307, 9.37715435, -6.42604256,
|
||||
1.62610793, -1.54000568, 2.52110147, 5.30775261, -4.10454893,
|
||||
-4.96251774},
|
||||
{-2.95554614, -5.18210888, 1.00015664, -4.03864431, -7.14954519,
|
||||
5.99929142, 5.86350155, 2.03810191, -4.23009968, 9.39885902, -5.68198299,
|
||||
2.72845244, 11.7133255, 0.838779449, -13.2235403, 2.94607735,
|
||||
-2.7902379}},
|
||||
{{2.86876941, -0.836064458, -0.374509573, -0.277966499, 3.20654631,
|
||||
-3.68510771, -7.76134634, 2.23905277, -8.35530376, 5.25071716,
|
||||
-1.38490796, -2.93542218, 0.509032726, -3.57361269, -2.82580233,
|
||||
-4.49954033, 2.91235542},
|
||||
{-4.37938213, 4.78577232, 2.03453469, 5.48564529, -1.05589461, -1.65940428,
|
||||
4.0130887, 5.26074123, 4.67537832, 0.791350365, 6.3880868, 2.50402451,
|
||||
7.6603322, -3.16343474, -2.71949649, 4.61576128, 1.3817997}},
|
||||
{{0.289200783, 7.06031752, -1.15099299, -5.29136801, -1.343642, -8.36283112,
|
||||
4.13158274, -1.93137062, 3.16199875, 2.21854591, 2.18270063, 0.77002573,
|
||||
6.90393353, -0.644045949, -5.62211609, -1.09085155, 1.07821059},
|
||||
{-3.04716778, -2.52233481, -5.99031925, 2.80152273, 0.340899587,
|
||||
0.667474508, -2.39674735, 8.83768654, -5.45613146, -1.55994594, -2.216362,
|
||||
1.49354, -4.27255821, -9.05310917, 5.90691471, -1.29772806, -8.50278}},
|
||||
{{-3.1383903, -7.71573353, 3.38072681, 6.07642221, -2.39587545, -7.84178352,
|
||||
-1.60108304, -8.6121521, -5.151721, 4.17612457, -2.86532378, 1.64645958,
|
||||
-0.37970829, -4.34561253, -0.454322815, 0.331385136, -5.74550819},
|
||||
{4.77026033, -5.51171303, -7.38155365, -5.38462543, 2.95842505, 5.18372536,
|
||||
0.521988213, 7.23966122, -4.90852165, 7.18465281, 2.99289083, 10.0519466,
|
||||
-2.09695673, 7.34368706, -2.40495348, 3.61603308, 0.131510735}},
|
||||
};
|
||||
|
||||
// Layer 0, QHead 0
|
||||
const float kGoldenQ[kNumTokens][kQBatchSize][kDimsToCompare] = {
|
||||
{{-0.574401975, 0.370210886, -0.426894158, -0.543187439, -0.0266762674,
|
||||
-0.177960411, -0.00839618221, 0.411925405, 0.536462784, 0.528389931,
|
||||
-0.499812007, -0.123897657, -0.0170236826, 0.266041577, -0.0781469196,
|
||||
-0.44081074, 0.185976267},
|
||||
{0.270543516, -0.109283224, -0.58602041, -0.358663559, -0.393124342,
|
||||
-0.0895933211, -0.632167816, 0.386703, 0.314152211, 0.0554139167,
|
||||
0.0241559595, -0.194484815, 0.143893063, 0.103837147, -0.384245932,
|
||||
-0.00418212265, 0.385817379}},
|
||||
{{-0.0331106335, -0.100827977, 0.322449774, 0.225943685, -0.384854138,
|
||||
-0.208085626, 0.0206767023, 0.287796348, -0.139513299, 0.255447835,
|
||||
-0.0845065042, -0.0619940236, 0.477489054, 0.517492294, -0.0172665715,
|
||||
-0.0302075297, 0.365989387},
|
||||
{-0.0266781822, -0.453293771, 0.560033202, 0.105156079, -0.35259968,
|
||||
0.711447716, -0.253611088, 0.0487165749, -0.086192511, -0.0338740349,
|
||||
-0.655441046, 0.00413730741, -0.510472536, -0.0748229772, -0.29113093,
|
||||
-0.0432077348, 0.09223634}},
|
||||
{{-0.321974993, -0.466039479, 0.207254037, -0.126807183, -0.192775592,
|
||||
-0.0953654051, 0.209789664, 0.405356169, -0.00627984107, -0.0590961352,
|
||||
0.0907663852, -0.190793216, -0.730463982, 0.340142608, -0.295675993,
|
||||
-0.165913597, -0.233714506},
|
||||
{-0.345578939, 0.394073665, 0.299743414, -0.0075177839, -0.288939595,
|
||||
0.127782941, -0.207550645, 0.0655022636, -0.705084503, -0.241842598,
|
||||
0.333820701, 0.217911497, 0.29735288, 0.0147881694, -0.152306199,
|
||||
-0.589594781, -0.373093933}},
|
||||
{{0.216089666, 0.0918798149, 0.0560657382, -0.157523662, -0.00141695142,
|
||||
0.51770103, 0.596379519, -0.271057904, 0.241035417, -0.275827706,
|
||||
0.112851456, 0.026878573, -0.579843462, -0.5116328, 0.192026839,
|
||||
0.125176072, 0.34234497},
|
||||
{-0.0744233653, 0.180814236, 0.170143247, -0.337861449, -0.175804421,
|
||||
0.213403732, -0.173699334, 0.109528325, -0.385727316, 0.109683953,
|
||||
0.475667775, 0.253016889, 0.477347463, 0.111096457, 0.394625545,
|
||||
0.0172286481, -0.357992649}},
|
||||
{{-0.350524545, -0.142550975, -0.212269634, -0.0589753427, -0.434021264,
|
||||
0.384472728, 0.445421219, -0.635599554, -0.246593416, 0.120986834,
|
||||
0.623568773, -0.161932915, -0.702406883, 0.44038102, 0.268234134,
|
||||
0.480264157, 0.103595078},
|
||||
{-0.227436215, 0.357608706, -0.25339672, -0.0683218762, -0.179259315,
|
||||
0.23657614, 0.559984326, 0.165754288, -0.0402980596, -0.101906747,
|
||||
-0.278261065, -0.16327399, 0.235923961, -0.428657919, -0.290629387,
|
||||
0.579215467, -0.0717103705}},
|
||||
{{-0.246389642, -0.266164362, -0.0967710763, -0.4011603, 0.242542207,
|
||||
0.0869855583, 0.20158039, 0.207793877, -0.0875666738, -0.242263764,
|
||||
-0.0462955758, -0.617374003, 0.454443514, 0.207072973, -0.0235372931,
|
||||
-0.0193868056, -0.660622239},
|
||||
{0.703284621, 0.0382430181, 0.43997851, -0.858277559, 0.342218578,
|
||||
0.414044619, 0.403636098, -0.579880178, -1.12243, -0.112913512,
|
||||
0.629238605, -0.0285760984, -0.152203664, -0.088969171, -0.0681343,
|
||||
0.476349175, 0.283238202}},
|
||||
{{0.138267457, 0.483219147, 0.230450034, -0.568304598, 0.204461277,
|
||||
-0.286731184, -0.416590065, -0.483460307, -0.561008453, 0.395195067,
|
||||
0.104367018, -0.196090236, -0.324770749, -0.0881370157, -0.626873195,
|
||||
0.0936089084, 0.262185335},
|
||||
{0.282603383, 0.0723766163, -0.206548154, 0.561849833, 0.482716829,
|
||||
0.135281503, -0.438841999, 0.472577304, -0.346201897, -0.0211652666,
|
||||
-0.0905084163, -0.168639392, -0.154975936, -0.303443581, -0.41771856,
|
||||
0.400717318, 0.426146686}},
|
||||
{{-0.0537007451, -0.227346331, -0.2871463, 0.247746795, -0.0975416005,
|
||||
-0.0123391449, 0.0612513907, -0.374673814, 0.283457696, 0.40945363,
|
||||
0.137944818, -0.0119741419, 0.775918365, -0.308365196, 0.230615795,
|
||||
-0.440364927, 0.218536288},
|
||||
{0.0688965544, -0.149037778, -0.246169299, 0.0599289536, -0.456733435,
|
||||
0.0808929354, 0.115154952, 0.0997388735, -0.408117741, 0.576600909,
|
||||
-0.193775773, 0.0340575948, -0.29254055, 0.695465446, 0.373336494,
|
||||
0.421431482, 0.00197479129}},
|
||||
{{0.402076721, -0.118151993, 0.542394996, 0.0382412486, -0.614983976,
|
||||
0.28617692, 0.318540633, -0.299300969, -0.177486539, 0.394140214,
|
||||
0.0644133314, -0.0321308076, 0.671587527, -0.0173831787, -0.219400048,
|
||||
-0.340277791, 0.5130288},
|
||||
{0.105372488, -0.145784974, 0.0695323348, -0.106080391, -0.755512118,
|
||||
0.975362539, -0.15056029, 0.58882606, -0.059625227, -0.810613,
|
||||
-0.321623206, 0.193939567, 0.0340242684, -0.626081824, 0.109950632,
|
||||
-0.141072854, 0.0177994221}},
|
||||
{{0.243249148, 0.0904035419, -0.472183734, -0.176162, 0.314925164,
|
||||
-0.191137731, 0.492265761, -0.0120046511, 0.824757636, 0.298175,
|
||||
0.148151726, -0.0197859108, -0.64297086, 0.432318538, -0.555079758,
|
||||
0.101636633, 0.155741245},
|
||||
{0.0523641109, 0.224086404, 0.0143201668, 0.0090854, 0.304901183,
|
||||
-0.391372293, 0.267655343, 0.117368169, 0.645064473, 0.336050332,
|
||||
-0.282133281, -0.231817603, 0.376230389, -0.575031936, -0.628365576,
|
||||
0.484799922, 0.0824087635}},
|
||||
};
|
||||
|
||||
void RunAttentionTest(AttentionImpl attention_impl) {
|
||||
TestState state;
|
||||
TestModelState model_state(state);
|
||||
TestAttentionState attention_state(state, model_state, kNumTokens,
|
||||
kQBatchSize, attention_impl);
|
||||
|
||||
GemmaAttention(attention_state.tokens.size(), 0, model_state.layer,
|
||||
attention_state.attention, *attention_state.qbatch, state.env,
|
||||
AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16));
|
||||
|
||||
CompareAttSumsWithGolden(attention_state.attention, kGoldenAttSums);
|
||||
CompareKVCacheWithGolden(model_state.config,
|
||||
hwy::Span<KVCache>(attention_state.kv_caches.data(),
|
||||
attention_state.kv_caches.size()),
|
||||
/*layer=*/0, /*kv_head=*/0, kGoldenK, kGoldenV);
|
||||
CompareQVecsWithGolden(model_state.config, attention_state.attention,
|
||||
/*q_head=*/0, kGoldenQ);
|
||||
}
|
||||
|
||||
void TestGemmaAttentionOld() { RunAttentionTest(AttentionImpl::kOld); }
|
||||
|
||||
void TestGemmaAttentionFlash() { RunAttentionTest(AttentionImpl::kFlash); }
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#if HWY_ONCE
|
||||
|
||||
namespace gcpp {
|
||||
HWY_BEFORE_TEST(AttentionTest);
|
||||
HWY_EXPORT_AND_TEST_P(AttentionTest, TestGemmaAttentionOld);
|
||||
HWY_EXPORT_AND_TEST_P(AttentionTest, TestGemmaAttentionFlash);
|
||||
HWY_AFTER_TEST();
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif
|
||||
|
|
@ -73,45 +73,38 @@ GemmaContext* GemmaContext::Create(const char* tokenizer_path,
|
|||
ThreadingArgs threading_args;
|
||||
threading_args.spin = gcpp::Tristate::kFalse;
|
||||
|
||||
LoaderArgs loader(tokenizer_path, weights_path);
|
||||
LogDebug("LoaderArgs created");
|
||||
threading_args.spin = gcpp::Tristate::kFalse;
|
||||
GemmaArgs args(LoaderArgs(tokenizer_path, weights_path), threading_args);
|
||||
|
||||
// Initialize cached args
|
||||
LogDebug("Initializing inference args");
|
||||
InferenceArgs inference_args;
|
||||
inference_args.Init();
|
||||
inference_args.max_generated_tokens = max_generated_tokens;
|
||||
inference_args.temperature = 0.7f;
|
||||
inference_args.top_k = 1;
|
||||
inference_args.deterministic = false;
|
||||
args.inference.max_generated_tokens = max_generated_tokens;
|
||||
args.inference.temperature = 0.7f;
|
||||
args.inference.top_k = 1;
|
||||
args.inference.deterministic = false;
|
||||
|
||||
ss.str("");
|
||||
ss << "Inference args initialized with max_tokens: " << max_generated_tokens
|
||||
<< ", temperature: " << inference_args.temperature
|
||||
<< ", top_k: " << inference_args.top_k << ", deterministic: "
|
||||
<< (inference_args.deterministic ? "true" : "false");
|
||||
<< ", temperature: " << args.inference.temperature
|
||||
<< ", top_k: " << args.inference.top_k << ", deterministic: "
|
||||
<< (args.inference.deterministic ? "true" : "false");
|
||||
LogDebug(ss.str().c_str());
|
||||
|
||||
return new GemmaContext(loader, inference_args, threading_args,
|
||||
max_generated_tokens);
|
||||
return new GemmaContext(args, max_generated_tokens);
|
||||
}
|
||||
|
||||
GemmaContext::GemmaContext(const LoaderArgs& loader,
|
||||
const InferenceArgs& inference_args,
|
||||
const ThreadingArgs& threading_args,
|
||||
int max_generated_tokens)
|
||||
: inference_args(inference_args),
|
||||
threading_args(threading_args),
|
||||
ctx(threading_args),
|
||||
GemmaContext::GemmaContext(const GemmaArgs& args, int max_generated_tokens)
|
||||
: args(args),
|
||||
ctx(args.threading),
|
||||
matmul_env(ctx),
|
||||
active_conversation_name("default"),
|
||||
model(loader, inference_args, matmul_env.ctx) {
|
||||
model(args, matmul_env.ctx) {
|
||||
std::stringstream ss;
|
||||
|
||||
LogDebug("Creating initial ConversationData");
|
||||
// Create the initial ConversationData object using make_shared
|
||||
active_conversation = std::make_shared<ConversationData>(
|
||||
model.Config(), inference_args, ctx.allocator);
|
||||
model.Config(), args.inference, ctx.allocator);
|
||||
|
||||
LogDebug(
|
||||
"Storing initial ConversationData in conversation_cache[\"default\"]");
|
||||
|
|
@ -172,8 +165,8 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
|
|||
// set up runtime config
|
||||
TimingInfo timing_info = {};
|
||||
RuntimeConfig runtime_config = {.stream_token = stream_token,
|
||||
.use_spinning = threading_args.spin};
|
||||
inference_args.CopyTo(runtime_config);
|
||||
.use_spinning = args.threading.spin};
|
||||
args.inference.CopyTo(runtime_config);
|
||||
size_t prefix_end = 0;
|
||||
|
||||
const ModelConfig& model_config = model.Config();
|
||||
|
|
@ -247,7 +240,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
|
|||
timing_info);
|
||||
|
||||
// prepare for next turn
|
||||
if (!inference_args.multiturn ||
|
||||
if (!args.inference.multiturn ||
|
||||
model_config.wrapping == PromptWrapping::PALIGEMMA) {
|
||||
// If not multiturn, or Paligemma (which handles turns differently),
|
||||
// reset the *active* conversation's position.
|
||||
|
|
|
|||
|
|
@ -53,8 +53,7 @@ typedef void (*GemmaLogCallback)(const char* message, void* user_data);
|
|||
|
||||
class GemmaContext {
|
||||
private:
|
||||
GemmaContext(const LoaderArgs& loader, const InferenceArgs& inference_args,
|
||||
const ThreadingArgs& threading_args, int max_generated_tokens);
|
||||
GemmaContext(const GemmaArgs& args, int max_generated_tokens);
|
||||
|
||||
public:
|
||||
static GemmaContext* Create(const char* tokenizer_path,
|
||||
|
|
@ -81,37 +80,37 @@ class GemmaContext {
|
|||
|
||||
// Set max generated tokens
|
||||
void SetMaxGeneratedTokens(size_t value) {
|
||||
inference_args.max_generated_tokens = value;
|
||||
args.inference.max_generated_tokens = value;
|
||||
LogDebug("Setting max_generated_tokens to configured value");
|
||||
}
|
||||
|
||||
// Set multiturn flag (0 = disabled, 1 = enabled)
|
||||
void SetMultiturn(int value) {
|
||||
inference_args.multiturn = value;
|
||||
args.inference.multiturn = value;
|
||||
LogDebug("Setting multiturn to configured value");
|
||||
}
|
||||
|
||||
// Set temperature for token generation
|
||||
void SetTemperature(float value) {
|
||||
inference_args.temperature = value;
|
||||
args.inference.temperature = value;
|
||||
LogDebug("Setting temperature to configured value");
|
||||
}
|
||||
|
||||
// Set top_k parameter for sampling
|
||||
void SetTopK(int value) {
|
||||
inference_args.top_k = value;
|
||||
args.inference.top_k = value;
|
||||
LogDebug("Setting top_k to configured value");
|
||||
}
|
||||
|
||||
// Set deterministic flag
|
||||
void SetDeterministic(bool value) {
|
||||
inference_args.deterministic = value;
|
||||
args.inference.deterministic = value;
|
||||
LogDebug("Setting deterministic flag to configured value");
|
||||
}
|
||||
|
||||
// Set prefill_tbatch_size
|
||||
void SetPrefillTbatchSize(size_t value) {
|
||||
inference_args.prefill_tbatch_size = value;
|
||||
args.inference.prefill_tbatch_size = value;
|
||||
LogDebug("Setting prefill_tbatch_size to configured value");
|
||||
}
|
||||
|
||||
|
|
@ -175,7 +174,7 @@ class GemmaContext {
|
|||
active_conversation->abs_pos = 0;
|
||||
// Replace the cache within the current ConversationData object
|
||||
active_conversation->kv_cache = std::make_unique<KVCache>(
|
||||
model.Config(), inference_args, ctx.allocator);
|
||||
model.Config(), args.inference, ctx.allocator);
|
||||
|
||||
LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
|
||||
} else {
|
||||
|
|
@ -193,7 +192,7 @@ class GemmaContext {
|
|||
LogDebug("Creating new conversation");
|
||||
// Create a new ConversationData object using make_shared
|
||||
conversation_cache[name] = std::make_shared<ConversationData>(
|
||||
model.Config(), inference_args, ctx.allocator);
|
||||
model.Config(), args.inference, ctx.allocator);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
@ -274,8 +273,7 @@ class GemmaContext {
|
|||
std::vector<int> token_buffer;
|
||||
|
||||
// Cached args (remain global for the context)
|
||||
InferenceArgs inference_args;
|
||||
ThreadingArgs threading_args;
|
||||
GemmaArgs args;
|
||||
ThreadingContext ctx;
|
||||
MatMulEnv matmul_env;
|
||||
|
||||
|
|
|
|||
|
|
@ -238,6 +238,7 @@ static ModelConfig ConfigGemma3_1B() {
|
|||
config.display_name = "Gemma3_1B";
|
||||
config.model = Model::GEMMA3_1B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
config.use_global_timescale = true;
|
||||
config.model_dim = 1152;
|
||||
config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer
|
||||
config.max_seq_len = 32 * 1024;
|
||||
|
|
@ -288,6 +289,7 @@ static ModelConfig ConfigGemma3_4B() {
|
|||
config.display_name = "Gemma3_4B";
|
||||
config.model = Model::GEMMA3_4B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
config.use_global_timescale = true;
|
||||
AddVitConfig(config, /*image_size=*/896);
|
||||
config.vocab_size = kGemmaV3VocabSize;
|
||||
config.vit_config.pool_dim = 4;
|
||||
|
|
@ -337,6 +339,7 @@ static ModelConfig ConfigGemma3_12B() {
|
|||
config.display_name = "Gemma3_12B";
|
||||
config.model = Model::GEMMA3_12B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
config.use_global_timescale = true;
|
||||
AddVitConfig(config, /*image_size=*/896);
|
||||
config.vocab_size = kGemmaV3VocabSize;
|
||||
config.vit_config.pool_dim = 4;
|
||||
|
|
@ -386,6 +389,7 @@ static ModelConfig ConfigGemma3_27B() {
|
|||
config.display_name = "Gemma3_27B";
|
||||
config.model = Model::GEMMA3_27B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
config.use_global_timescale = true;
|
||||
AddVitConfig(config, /*image_size=*/896);
|
||||
config.vocab_size = kGemmaV3VocabSize;
|
||||
config.vit_config.pool_dim = 4;
|
||||
|
|
@ -495,19 +499,19 @@ const char* ModelPrefix(Model model) {
|
|||
}
|
||||
|
||||
PromptWrapping ChooseWrapping(const Model model, Tristate wrapping) {
|
||||
if (IsPaliGemma(model)) {
|
||||
const PromptWrapping config_wrapping = ConfigFromModel(model).wrapping;
|
||||
|
||||
// For models with a fixed wrapping mode, ignore user override.
|
||||
if (config_wrapping == PromptWrapping::PALIGEMMA ||
|
||||
config_wrapping == PromptWrapping::GEMMA_VLM) {
|
||||
if (wrapping != Tristate::kDefault) {
|
||||
HWY_WARN("Ignoring unnecessary --wrapping for PaliGemma models.");
|
||||
HWY_WARN("Ignoring unnecessary --wrapping for model %s.",
|
||||
ModelPrefix(model));
|
||||
}
|
||||
return PromptWrapping::PALIGEMMA;
|
||||
return config_wrapping;
|
||||
}
|
||||
if (IsVLM(model)) {
|
||||
if (wrapping != Tristate::kDefault) {
|
||||
HWY_WARN("Ignoring unnecessary --wrapping for VLM models.");
|
||||
}
|
||||
return PromptWrapping::GEMMA_VLM;
|
||||
}
|
||||
// Default to IT unless --wrapping=0.
|
||||
|
||||
// For other models, default to IT unless --wrapping=0 is passed.
|
||||
return wrapping == Tristate::kFalse ? PromptWrapping::GEMMA_PT
|
||||
: PromptWrapping::GEMMA_IT;
|
||||
}
|
||||
|
|
@ -674,7 +678,9 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) {
|
|||
return Model::GEMMA3_270M;
|
||||
|
||||
case 26:
|
||||
if (layer_types & kDeducedViT) return Model::GEMMA3_1B;
|
||||
if (layer_types & (kDeducedViT|kDeducedKqNorm)) {
|
||||
return Model::GEMMA3_1B;
|
||||
}
|
||||
return Model::GEMMA2_2B;
|
||||
case 27:
|
||||
return (layer_types & kDeduced448) ? Model::PALIGEMMA2_3B_448
|
||||
|
|
@ -706,4 +712,11 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) {
|
|||
}
|
||||
}
|
||||
|
||||
AttentionImpl GetAttentionImpl(const std::string& impl) {
|
||||
if (impl == "old") return AttentionImpl::kOld;
|
||||
if (impl == "flash") return AttentionImpl::kFlash;
|
||||
HWY_WARN("Unknown attention implementation: %s. Using kOld.\n", impl.c_str());
|
||||
return AttentionImpl::kOld;
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -80,6 +80,38 @@ static inline bool EnumValid(LayerAttentionType type) {
|
|||
return type == LayerAttentionType::kGemma || type == LayerAttentionType::kVit;
|
||||
}
|
||||
|
||||
enum class AttentionImpl {
|
||||
kOld,
|
||||
kFlash,
|
||||
kSentinel,
|
||||
};
|
||||
|
||||
AttentionImpl GetAttentionImpl(const std::string& impl);
|
||||
|
||||
/*
|
||||
* Returns a bitmask of flags to pass to attention functions based on the
|
||||
* attention implementation selected.
|
||||
*
|
||||
* If `hwy_native_dot_bf16` is true, the function will use the old attention
|
||||
* implementation, ignoring `impl`.
|
||||
*
|
||||
* `hwy_native_dot_bf16` needs to be passed in, because the HWY_NATIVE_DOT_BF16
|
||||
* macro is not available outside of highway instrumented translation units and
|
||||
* cannot be made accessible from .h files.
|
||||
*/
|
||||
static inline int AttentionImplToFlags(AttentionImpl impl,
|
||||
int hwy_native_dot_bf16) {
|
||||
if (hwy_native_dot_bf16) return kAttentionUseOld;
|
||||
|
||||
switch (impl) {
|
||||
case AttentionImpl::kOld:
|
||||
return kAttentionUseOld;
|
||||
case AttentionImpl::kFlash:
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Post attention and ffw normalization type.
|
||||
enum class PostNormType {
|
||||
None,
|
||||
|
|
@ -184,13 +216,6 @@ enum class Model {
|
|||
// in Specifier and thus does not change.
|
||||
const char* ModelPrefix(Model model);
|
||||
|
||||
// Gemma3 is multimodal and has a different prompt wrapping than PaliGemma.
|
||||
// This is used for deducing the PromptWrapping for pre-2025 BlobStore.
|
||||
static inline bool IsVLM(Model model) {
|
||||
return model == Model::GEMMA3_4B || model == Model::GEMMA3_1B ||
|
||||
model == Model::GEMMA3_12B || model == Model::GEMMA3_27B;
|
||||
}
|
||||
|
||||
static inline bool IsPaliGemma(Model model) {
|
||||
if (model == Model::PALIGEMMA2_3B_224 || model == Model::PALIGEMMA2_3B_448 ||
|
||||
model == Model::PALIGEMMA2_10B_224 ||
|
||||
|
|
@ -383,6 +408,8 @@ struct ModelConfig : public IFields {
|
|||
|
||||
internal.VisitFields(visitor);
|
||||
|
||||
visitor(use_global_timescale);
|
||||
|
||||
// Append new fields here, then update `python/configs.cc`.
|
||||
}
|
||||
|
||||
|
|
@ -481,6 +508,7 @@ struct ModelConfig : public IFields {
|
|||
std::vector<std::string> scale_base_names;
|
||||
|
||||
InternalModelConfig internal;
|
||||
bool use_global_timescale = false; // for Gemma 3
|
||||
};
|
||||
|
||||
// Returns the sub-config for the ViT model of the PaliGemma model.
|
||||
|
|
@ -489,6 +517,7 @@ ModelConfig GetVitConfig(const ModelConfig& config);
|
|||
enum DeducedLayerTypes {
|
||||
kDeducedViT = 2,
|
||||
kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224.
|
||||
kDeducedKqNorm = 8,
|
||||
};
|
||||
|
||||
// layer_types is one or more of `DeducedLayerTypes`.
|
||||
|
|
|
|||
|
|
@ -17,12 +17,18 @@
|
|||
#include <stdint.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||
#include "gemma/flash_structs.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "util/zones.h"
|
||||
#include "hwy/base.h"
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||
#endif // HWY_DISABLED_TARGETS
|
||||
|
|
@ -30,7 +36,6 @@
|
|||
#include "gemma/activations.h"
|
||||
#include "gemma/configs.h" // kMaxQKVDim
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
||||
|
|
@ -59,7 +64,7 @@ static constexpr size_t kNFx8HTileSize = 8;
|
|||
// q has shape [batch, qbatch][head, qkv_dim].
|
||||
// q_t has shape [qkv_dim][qbatch, head, batch] in order to make the maximum
|
||||
// possible consecutive elements have the same KV.
|
||||
static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
|
||||
static void TransposeQ(const MatPtrT<float>& q, MatPtrT<BF16>& q_t,
|
||||
const size_t qbatch_size, ThreadingContext& ctx) {
|
||||
// Group floats by the number of floats in a cache line.
|
||||
const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float);
|
||||
|
|
@ -70,12 +75,13 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
|
|||
for (size_t lane = 0; lane < kNF; ++lane) {
|
||||
size_t q_row = task * kNF + lane;
|
||||
if (q_row >= q_t.Rows()) break;
|
||||
float* HWY_RESTRICT qt_row = q_t.Row(q_row);
|
||||
BF16* HWY_RESTRICT qt_row = q_t.Row(q_row);
|
||||
for (size_t qi = 0; qi < qbatch_size; ++qi) {
|
||||
for (size_t h = 0; h < num_heads; ++h) {
|
||||
for (size_t b = 0; b < batch_size; ++b) {
|
||||
qt_row[(qi * num_heads + h) * batch_size + b] =
|
||||
q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row];
|
||||
hwy::ConvertScalarTo<BF16>(
|
||||
q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -84,45 +90,48 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
|
|||
{
|
||||
const size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF);
|
||||
// Better than kFlat.
|
||||
ParallelFor(ParallelismStrategy::kHierarchical, num_tasks, ctx,
|
||||
ParallelFor(Parallelism::kHierarchical, num_tasks, ctx,
|
||||
/*cluster_idx=*/0, Callers::kFlashTransposeQ, func);
|
||||
}
|
||||
}
|
||||
|
||||
// Updates q in place for RMSNorm and positional encoding.
|
||||
void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
|
||||
MatPtrT<KV_t>& q, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer,
|
||||
const AttentionActivations& activations,
|
||||
MatPtrT<float>& q,
|
||||
const MatPtr& query_norm_scale,
|
||||
const size_t layer_idx,
|
||||
const AttentionActivationsPtrs& activations,
|
||||
ThreadingContext& ctx) {
|
||||
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
|
||||
const float query_scale = activations.query_scale;
|
||||
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
|
||||
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionRmsNormAndPositionalEncoding);
|
||||
size_t qi = div_qbatch.Remainder(task);
|
||||
size_t batch_idx = div_qbatch.Divide(task);
|
||||
for (size_t h = 0; h < layer.layer_config.heads; ++h) {
|
||||
for (size_t h = 0; h < layer_config.heads; ++h) {
|
||||
const size_t tq_idx = qbatch.Size() * batch_idx + qi;
|
||||
// Find the token position in the query and calculate
|
||||
// the range of cache positions to attend to.
|
||||
const size_t pos = qbatch.Pos(qi) + batch_idx;
|
||||
float* HWY_RESTRICT q_row =
|
||||
q.Row(tq_idx) + h * layer.layer_config.qkv_dim;
|
||||
constexpr size_t offset = 0; // placeholder, do not remove
|
||||
const size_t pos =
|
||||
qbatch.Pos(qi) + batch_idx + offset;
|
||||
float* HWY_RESTRICT q_row = q.Row(tq_idx) + h * layer_config.qkv_dim;
|
||||
// Apply rope and scaling to Q.
|
||||
if (layer.query_norm_scale.HasPtr()) {
|
||||
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
|
||||
if (query_norm_scale.HasPtr()) {
|
||||
CallUpcasted(&query_norm_scale, [&](const auto* weights_t) {
|
||||
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q_row,
|
||||
layer.layer_config.qkv_dim, ctx, worker);
|
||||
layer_config.qkv_dim, ctx, worker);
|
||||
});
|
||||
}
|
||||
PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx, worker,
|
||||
pos, query_scale);
|
||||
PositionalEncodingQK(q_row, layer_idx, activations, ctx, worker, pos,
|
||||
query_scale);
|
||||
}
|
||||
};
|
||||
{
|
||||
// kHierarchical is not worth the extra sync overhead because the tasks are
|
||||
// very lightweight.
|
||||
ParallelFor(ParallelismStrategy::kFlat, num_tokens * qbatch.Size(), ctx,
|
||||
ParallelFor(Parallelism::kFlat, num_tokens * qbatch.Size(), ctx,
|
||||
/*cluster_idx=*/0, Callers::kFlashRMSNormAndPositionalEncoding,
|
||||
func);
|
||||
}
|
||||
|
|
@ -152,15 +161,20 @@ void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max,
|
|||
|
||||
// Calculates the complete attention outputs for a single row of q.
|
||||
void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
|
||||
const float* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
|
||||
const BF16* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
|
||||
const MatPtrT<KV_t>& v, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer,
|
||||
const AttentionActivations& activations,
|
||||
const AttentionActivationsPtrs& activations,
|
||||
float* HWY_RESTRICT att_out, ThreadingContext& ctx,
|
||||
const size_t worker) {
|
||||
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);
|
||||
const hn::ScalableTag<BF16> dbf;
|
||||
const size_t qkv_dim = k.Cols();
|
||||
|
||||
const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
|
||||
float m = Dot(q, k.Row(pos_mod), k.Cols());
|
||||
// TODO: Mixed-mode can be further improved for Turin: we can demote right
|
||||
// before we do the dot product instruction, rather than promote both to f32.
|
||||
// But some potential accuracy loss there, needs evaluation first.
|
||||
float m = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
|
||||
if (float cap = activations.config.att_cap; cap > 0.0f) {
|
||||
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
|
||||
m = cap * std::tanh(m / cap);
|
||||
|
|
@ -170,7 +184,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
|
|||
MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker);
|
||||
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
|
||||
const size_t pos_mod = activations.div_seq_len.Remainder(pos);
|
||||
float x = Dot(q, k.Row(pos_mod), k.Cols());
|
||||
float x = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
|
||||
SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
|
||||
v.Row(pos_mod), v.Cols(), att_out);
|
||||
}
|
||||
|
|
@ -180,25 +194,27 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
|
|||
// the dot products of NF rows of Q for a single K timestep.
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
|
||||
const size_t k_pos, const MatPtrT<KV_t>& q,
|
||||
const size_t k_pos, const MatPtrT<BF16>& q,
|
||||
const MatPtrT<KV_t>& k) {
|
||||
const hn::ScalableTag<BF16> dbf;
|
||||
const size_t qkv_dim = k.Cols();
|
||||
|
||||
hn::TFromD<DF> results[hn::MaxLanes(df)];
|
||||
for (size_t i = 0; i < hn::Lanes(df); ++i) {
|
||||
results[i] = Dot(q.Row(0) + q_offsets[i], k.Row(k_pos), k.Cols());
|
||||
results[i] = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[i], qkv_dim), 0,
|
||||
k.Row(k_pos), qkv_dim);
|
||||
}
|
||||
return hn::LoadU(df, results);
|
||||
}
|
||||
|
||||
// Returns an NF Q rows by 8 K rows tile of Q.K dot products, in single
|
||||
// precision.
|
||||
// Returns an NF Q rows by 8 K rows tile of Q.K dot products.
|
||||
// This is the result of NF rows of Q against 8 K timesteps, with positions
|
||||
// given by k_pos[0..7]. Q has been transposed so that the NF rows are read in
|
||||
// consecutive elements, and other columns by adding q_stride.
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
|
||||
const MatPtrT<KV_t>& k, const size_t* k_pos, VF& sum0,
|
||||
VF& sum1, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6,
|
||||
VF& sum7) {
|
||||
void QDotKTile(DF df, const BF16* HWY_RESTRICT q, const size_t q_stride,
|
||||
const MatPtrT<KV_t>& k, const size_t* k_pos, VF& sum0, VF& sum1,
|
||||
VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7) {
|
||||
constexpr size_t kHTileSize = kNFx8HTileSize;
|
||||
sum0 = hn::Zero(df);
|
||||
sum1 = hn::Zero(df);
|
||||
|
|
@ -209,11 +225,16 @@ void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
|
|||
sum6 = hn::Zero(df);
|
||||
sum7 = hn::Zero(df);
|
||||
const float* HWY_RESTRICT k_row[kHTileSize];
|
||||
for (int i = 0; i < kHTileSize; ++i) {
|
||||
for (size_t i = 0; i < kHTileSize; ++i) {
|
||||
k_row[i] = k.Row(k_pos[i]);
|
||||
}
|
||||
|
||||
const hn::Rebind<BF16, DF> dbfh;
|
||||
using VBF = hn::Vec<decltype(dbfh)>;
|
||||
|
||||
for (size_t i = 0; i < k.Cols(); ++i) {
|
||||
VF q_vec = hn::Load(df, q);
|
||||
const VBF q_vec_bf = hn::Load(dbfh, q);
|
||||
const VF q_vec = hn::PromoteTo(df, q_vec_bf);
|
||||
VF k_0 = hn::Set(df, k_row[0][i]);
|
||||
sum0 = hn::MulAdd(q_vec, k_0, sum0);
|
||||
VF k_1 = hn::Set(df, k_row[1][i]);
|
||||
|
|
@ -266,31 +287,30 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
|
|||
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
|
||||
// max_last_pos].
|
||||
void TileFlashAttention(
|
||||
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
|
||||
const StridedView<float>& qT, const MatPtrT<KV_t>& k,
|
||||
const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos,
|
||||
const size_t min_last_pos, const size_t max_last_pos,
|
||||
const MatPtrT<KV_t>& v, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer, const AttentionActivations& activations,
|
||||
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
|
||||
ThreadingContext& ctx, const size_t worker) {
|
||||
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
|
||||
const StridedView<BF16>& qT, const MatPtrT<KV_t>& k, const size_t start_pos,
|
||||
const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
|
||||
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
|
||||
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
|
||||
const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx,
|
||||
const size_t worker) {
|
||||
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention);
|
||||
constexpr int kHTileSize = kNFx8HTileSize;
|
||||
constexpr size_t kHTileSize = kNFx8HTileSize;
|
||||
using DF = hn::ScalableTag<float>;
|
||||
const DF df;
|
||||
using VF = hn::Vec<DF>;
|
||||
using DI = hn::ScalableTag<uint32_t>;
|
||||
const DI di;
|
||||
using VI = hn::Vec<DI>;
|
||||
const int kVTileSize = hn::Lanes(df);
|
||||
for (int i = 0; i < kVTileSize; ++i) {
|
||||
const size_t kVTileSize = hn::Lanes(df);
|
||||
for (size_t i = 0; i < kVTileSize; ++i) {
|
||||
hwy::ZeroBytes(att_out.Row(0) + out_offsets[i],
|
||||
v.Cols() * sizeof(att_out.Row(0)[0]));
|
||||
}
|
||||
VI lasts = hn::LoadU(di, last_pos);
|
||||
VF old_m = hn::Set(df, -std::numeric_limits<float>::max() / 2.0f);
|
||||
VF old_d = hn::Zero(df);
|
||||
const float* HWY_RESTRICT qT_row = qT.Row(0);
|
||||
const BF16* HWY_RESTRICT qT_row = qT.Row(0);
|
||||
const size_t qT_stride = qT.Stride();
|
||||
size_t position = start_pos;
|
||||
while (position + kHTileSize - 1 <= min_last_pos) {
|
||||
|
|
@ -299,8 +319,7 @@ void TileFlashAttention(
|
|||
k_pos[i] = activations.div_seq_len.Remainder(position + i);
|
||||
}
|
||||
VF x0, x1, x2, x3, x4, x5, x6, x7;
|
||||
QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6,
|
||||
x7);
|
||||
QDotKTile(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6, x7);
|
||||
if (activations.config.att_cap > 0.0f) {
|
||||
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
|
||||
VF cap = hn::Set(df, activations.config.att_cap);
|
||||
|
|
@ -374,7 +393,7 @@ void TileFlashAttention(
|
|||
// This is the result of 4 rows of Q against NF K timesteps, with positions
|
||||
// given by k_offsets[0..NF].
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
|
||||
void QDotKTilex4(DF df, const BF16* HWY_RESTRICT q,
|
||||
const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<KV_t>& k,
|
||||
const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1,
|
||||
VF& sum2, VF& sum3) {
|
||||
|
|
@ -389,13 +408,13 @@ void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
|
|||
VI k_offsets_vec = hn::LoadU(di, k_offsets);
|
||||
for (size_t i = 0; i < k.Cols(); ++i) {
|
||||
VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec);
|
||||
VF q_0 = hn::Set(df, q[q_offsets[0] + i]);
|
||||
VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[0] + i]));
|
||||
sum0 = hn::MulAdd(q_0, k_vec, sum0);
|
||||
VF q_1 = hn::Set(df, q[q_offsets[1] + i]);
|
||||
VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[1] + i]));
|
||||
sum1 = hn::MulAdd(q_1, k_vec, sum1);
|
||||
VF q_2 = hn::Set(df, q[q_offsets[2] + i]);
|
||||
VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[2] + i]));
|
||||
sum2 = hn::MulAdd(q_2, k_vec, sum2);
|
||||
VF q_3 = hn::Set(df, q[q_offsets[3] + i]);
|
||||
VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[3] + i]));
|
||||
sum3 = hn::MulAdd(q_3, k_vec, sum3);
|
||||
}
|
||||
}
|
||||
|
|
@ -410,23 +429,202 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
|
|||
float scale = old_d * std::exp(old_max - m);
|
||||
old_d = hn::ReduceSum(df, x) + scale;
|
||||
old_max = m;
|
||||
float one_over_d = 1.0f / old_d;
|
||||
if (old_d > 0.0f) {
|
||||
const float one_over_d = 1.0f / old_d;
|
||||
scale *= one_over_d;
|
||||
x = hn::Mul(x, hn::Set(df, one_over_d));
|
||||
} else {
|
||||
scale = 0.0f;
|
||||
x = hn::Zero(df);
|
||||
}
|
||||
return scale;
|
||||
}
|
||||
|
||||
// Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to
|
||||
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
|
||||
// max_last_pos].
|
||||
void TileFlashAttention4(
|
||||
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
|
||||
// Reduces each of x and stores in following lanes of max (tested with float32)
|
||||
template <class DF, typename T = hn::TFromD<DF>,
|
||||
class DF4 = hn::CappedTag<T, 4>, class VF4 = hn::Vec<DF4>,
|
||||
class VF = hn::Vec<DF>, typename F>
|
||||
static VF4 HWY_INLINE Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,
|
||||
F reducer) {
|
||||
const DF4 df4;
|
||||
constexpr size_t kMaxLanes = hn::MaxLanes(df);
|
||||
HWY_LANES_CONSTEXPR size_t kLanes = hn::Lanes(df);
|
||||
HWY_ALIGN T x_transposed[4 * kMaxLanes];
|
||||
hn::StoreInterleaved4<DF>(x_0, x_1, x_2, x_3, df, x_transposed);
|
||||
VF4 result = hn::Load(df4, x_transposed);
|
||||
for (int i = 1; i < kLanes; ++i) {
|
||||
result = reducer(result, hn::Load(df4, x_transposed + i * 4));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Handles Up to 4 Q rows by NF*2 timesteps of flash attention.
|
||||
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
|
||||
static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
|
||||
DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
|
||||
VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
|
||||
float* HWY_RESTRICT old_max, float* HWY_RESTRICT old_d,
|
||||
float* HWY_RESTRICT scales) {
|
||||
using DF4 = hn::CappedTag<float, 4>;
|
||||
const DF4 df4;
|
||||
using VF4 = hn::Vec<DF4>;
|
||||
static_assert(kNumQueries >= 1 && kNumQueries <= 4);
|
||||
VF4 new_max = hn::Set(df4, -std::numeric_limits<float>::max() / 2.0f);
|
||||
VF max_0, max_1, max_2, max_3 = hn::Zero(df);
|
||||
max_0 = hn::Max(x_0_p0, x_0_p1);
|
||||
if constexpr (kNumQueries >= 2) {
|
||||
max_1 = hn::Max(x_1_p0, x_1_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 3) {
|
||||
max_2 = hn::Max(x_2_p0, x_2_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 4) {
|
||||
max_3 = hn::Max(x_3_p0, x_3_p1);
|
||||
}
|
||||
if constexpr (kNumQueries == 1) {
|
||||
new_max = hn::InsertLane(new_max, 0, hn::ReduceMax(df, max_0));
|
||||
} else {
|
||||
new_max = Reduce4(df, max_0, max_1, max_2, max_3,
|
||||
[](auto a, auto b) { return hn::Max(a, b); });
|
||||
}
|
||||
if (att_cap > 0.0f) {
|
||||
VF4 cap = hn::Set(df4, att_cap);
|
||||
VF4 one_over_cap = hn::Set(df4, one_over_att_cap);
|
||||
new_max = hn::Mul(cap, hn::Tanh(df4, hn::Mul(new_max, one_over_cap)));
|
||||
}
|
||||
VF4 old_max_vf = hn::Set(df4, -std::numeric_limits<float>::max() / 2.0f);
|
||||
old_max_vf = hn::LoadU(df4, old_max);
|
||||
new_max = hn::Max(new_max, old_max_vf);
|
||||
// TODO figure out what was wrong with broadcasts and change to that.
|
||||
HWY_ALIGN float tmp_max[4];
|
||||
hn::Store(new_max, df4, tmp_max);
|
||||
if constexpr (kNumQueries >= 1) {
|
||||
const VF new_max_0 = hn::Set(df, tmp_max[0]);
|
||||
x_0_p0 = hn::Exp(df, hn::Sub(x_0_p0 , new_max_0));
|
||||
x_0_p1 = hn::Exp(df, hn::Sub(x_0_p1, new_max_0));
|
||||
}
|
||||
if constexpr (kNumQueries >= 2) {
|
||||
const VF new_max_0 = hn::Set(df, tmp_max[1]);
|
||||
x_1_p0 = hn::Exp(df, hn::Sub(x_1_p0, new_max_0));
|
||||
x_1_p1 = hn::Exp(df, hn::Sub(x_1_p1, new_max_0));
|
||||
}
|
||||
if constexpr (kNumQueries >= 3) {
|
||||
const VF new_max_0 = hn::Set(df, tmp_max[2]);
|
||||
x_2_p0 = hn::Exp(df, hn::Sub(x_2_p0, new_max_0));
|
||||
x_2_p1 = hn::Exp(df, hn::Sub(x_2_p1, new_max_0));
|
||||
}
|
||||
if constexpr (kNumQueries >= 4) {
|
||||
const VF new_max_0 = hn::Set(df, tmp_max[3]);
|
||||
x_3_p0 = hn::Exp(df, hn::Sub(x_3_p0, new_max_0));
|
||||
x_3_p1 = hn::Exp(df, hn::Sub(x_3_p1, new_max_0));
|
||||
}
|
||||
VF4 old_d_vf = hn::Set(df4, 0.0f);
|
||||
old_d_vf = hn::LoadU(df4, old_d);
|
||||
VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max)));
|
||||
|
||||
hn::StoreU(new_max, df4, old_max);
|
||||
|
||||
VF4 x_sum = hn::Zero(df4);
|
||||
if constexpr (kNumQueries == 1) {
|
||||
x_sum = hn::Set(df4, hn::ReduceSum(df, x_0_p0) + hn::ReduceSum(df, x_0_p1));
|
||||
} else {
|
||||
VF x_0_sum = hn::Add(x_0_p0, x_0_p1);
|
||||
VF x_1_sum = hn::Add(x_1_p0, x_1_p1);
|
||||
VF x_2_sum = hn::Add(x_2_p0, x_2_p1);
|
||||
VF x_3_sum = hn::Add(x_3_p0, x_3_p1);
|
||||
x_sum = Reduce4(df, x_0_sum, x_1_sum, x_2_sum, x_3_sum,
|
||||
[](auto a, auto b) { return hn::Add(a, b); });
|
||||
}
|
||||
old_d_vf = hn::Add(scale, x_sum);
|
||||
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df4, 0.0f));
|
||||
const VF zero = hn::Zero(df);
|
||||
const VF4 zero4 = hn::Zero(df4);
|
||||
const VF4 one_over_d =
|
||||
hn::MaskedDivOr(zero4, non_zero_mask, hn::Set(df4, 1.0f), old_d_vf);
|
||||
float tmp_one_over_d[4];
|
||||
hn::Store(one_over_d, df4, tmp_one_over_d);
|
||||
hn::Store(old_d_vf, df4, old_d);
|
||||
scale = hn::Mul(scale, one_over_d);
|
||||
hn::Store(scale, df4, scales);
|
||||
if (hn::ExtractLane(old_d_vf, 0) > 0.0f) {
|
||||
const VF one_over_d_0 = hn::Set(df, tmp_one_over_d[0]);
|
||||
x_0_p0 = hn::Mul(x_0_p0, one_over_d_0);
|
||||
x_0_p1 = hn::Mul(x_0_p1, one_over_d_0);
|
||||
} else {
|
||||
x_0_p0 = zero;
|
||||
x_0_p1 = zero;
|
||||
}
|
||||
if constexpr (kNumQueries >= 2) {
|
||||
if (hn::ExtractLane(old_d_vf, 1) > 0.0f) {
|
||||
const VF one_over_d_1 = hn::Set(df, tmp_one_over_d[1]);
|
||||
x_1_p0 = hn::Mul(x_1_p0, one_over_d_1);
|
||||
x_1_p1 = hn::Mul(x_1_p1, one_over_d_1);
|
||||
} else {
|
||||
x_1_p0 = zero;
|
||||
x_1_p1 = zero;
|
||||
}
|
||||
}
|
||||
if constexpr (kNumQueries >= 3) {
|
||||
if (hn::ExtractLane(old_d_vf, 2) > 0.0f) {
|
||||
const VF one_over_d_2 = hn::Set(df, tmp_one_over_d[2]);
|
||||
x_2_p0 = hn::Mul(x_2_p0, one_over_d_2);
|
||||
x_2_p1 = hn::Mul(x_2_p1, one_over_d_2);
|
||||
} else {
|
||||
x_2_p0 = zero;
|
||||
x_2_p1 = zero;
|
||||
}
|
||||
}
|
||||
if constexpr (kNumQueries >= 4) {
|
||||
if (hn::ExtractLane(old_d_vf, 3) > 0.0f) {
|
||||
const VF one_over_d_3 = hn::Set(df, tmp_one_over_d[3]);
|
||||
x_3_p0 = hn::Mul(x_3_p0, one_over_d_3);
|
||||
x_3_p1 = hn::Mul(x_3_p1, one_over_d_3);
|
||||
} else {
|
||||
x_3_p0 = zero;
|
||||
x_3_p1 = zero;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Implements flash attention for a strip of 4 query vectors.
|
||||
// It iterates through timesteps in K from `start_pos` up to `max_last_pos`.
|
||||
// Timesteps up to `min_last_pos` (*) are processed in tiles of shape 4 Q rows
|
||||
// by NF timesteps in K for efficiency while timesteps between `min_last_pos +
|
||||
// 1` and `max_last_pos` are processed one-by-one to handle differing `last_pos`
|
||||
// values within the strip.
|
||||
// (*) Actually, it only iterates through
|
||||
// `min_last_pos - (min_last_pos + 1 - start_pos) % NF` in tiles, as the tiled
|
||||
// computation can, for obvious reasons, only process an integer number of
|
||||
// tiles.
|
||||
//
|
||||
// @param q The query matrix [batch_size * q_heads, qkv_dim] in BF16 format.
|
||||
// @param q_offsets Offsets from `q.Row(0)` to the start of the 4 query
|
||||
// vectors to be processed in this tile.
|
||||
// @param k Key matrix [seq_len, qkv_dim] from KV cache.
|
||||
// @param start_pos The first token position in the KV cache to attend to.
|
||||
// @param last_pos An array of 4 indices giving the last token position
|
||||
// (inclusive) that each of the 4 queries may attend to.
|
||||
// @param min_last_pos The minimum value in `last_pos`. Timesteps up to this
|
||||
// position can be processed efficiently in batches.
|
||||
// @param max_last_pos The maximum value in `last_pos`. Timesteps between
|
||||
// `min_last_pos + 1` and this position are processed individually to
|
||||
// respect each query's `last_pos` limit.
|
||||
// @param v Value matrix [seq_len, qkv_dim] from KV cache.
|
||||
// @param layer_idx The index of the current transformer layer.
|
||||
// @param activations Attention configurations and buffers.
|
||||
// @param att_out Output buffer for attention results.
|
||||
// @param out_offsets Offsets from `att_out.Row(0)` to store the 4 output
|
||||
// vectors.
|
||||
// @param ctx Threading context.
|
||||
// @param worker Worker thread index.
|
||||
Tile4FlashState TileFlashAttention4(
|
||||
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
|
||||
const MatPtrT<KV_t>& k, const size_t start_pos,
|
||||
const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
|
||||
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer, const AttentionActivations& activations,
|
||||
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
|
||||
ThreadingContext& ctx, const size_t worker) {
|
||||
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
|
||||
const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx,
|
||||
const size_t worker) {
|
||||
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention4);
|
||||
using DF = hn::ScalableTag<float>;
|
||||
const DF df;
|
||||
|
|
@ -440,14 +638,7 @@ void TileFlashAttention4(
|
|||
hwy::ZeroBytes(att_out.Row(0) + out_offsets[i],
|
||||
v.Cols() * sizeof(att_out.Row(0)[0]));
|
||||
}
|
||||
float old_m0 = -std::numeric_limits<float>::max() / 2.0f;
|
||||
float old_m1 = -std::numeric_limits<float>::max() / 2.0f;
|
||||
float old_m2 = -std::numeric_limits<float>::max() / 2.0f;
|
||||
float old_m3 = -std::numeric_limits<float>::max() / 2.0f;
|
||||
float old_d0 = 0.0f;
|
||||
float old_d1 = 0.0f;
|
||||
float old_d2 = 0.0f;
|
||||
float old_d3 = 0.0f;
|
||||
Tile4FlashState state;
|
||||
size_t position = start_pos;
|
||||
while (position + kHTileSize - 1 <= min_last_pos) {
|
||||
int32_t k_offsets[kMaxNF];
|
||||
|
|
@ -467,46 +658,62 @@ void TileFlashAttention4(
|
|||
x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap)));
|
||||
x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap)));
|
||||
}
|
||||
scales[0] = SingleFlashAttentionRowVector(df, x0, old_m0, old_d0);
|
||||
scales[1] = SingleFlashAttentionRowVector(df, x1, old_m1, old_d1);
|
||||
scales[2] = SingleFlashAttentionRowVector(df, x2, old_m2, old_d2);
|
||||
scales[3] = SingleFlashAttentionRowVector(df, x3, old_m3, old_d3);
|
||||
scales[0] = SingleFlashAttentionRowVector(df, x0, state.row_states[0].max,
|
||||
state.row_states[0].d);
|
||||
scales[1] = SingleFlashAttentionRowVector(df, x1, state.row_states[1].max,
|
||||
state.row_states[1].d);
|
||||
scales[2] = SingleFlashAttentionRowVector(df, x2, state.row_states[2].max,
|
||||
state.row_states[2].d);
|
||||
scales[3] = SingleFlashAttentionRowVector(df, x3, state.row_states[3].max,
|
||||
state.row_states[3].d);
|
||||
MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0),
|
||||
out_offsets, v.Cols());
|
||||
position += kHTileSize;
|
||||
}
|
||||
const hn::ScalableTag<BF16> dbf;
|
||||
const size_t qkv_dim = k.Cols();
|
||||
|
||||
while (position <= max_last_pos) {
|
||||
size_t k_pos = activations.div_seq_len.Remainder(position);
|
||||
if (position <= last_pos[0]) {
|
||||
// Past the last position, x0 doesn't count.
|
||||
float x0 = Dot(q.Row(0) + q_offsets[0], k.Row(k_pos), k.Cols());
|
||||
SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0,
|
||||
float x0 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[0], qkv_dim), 0,
|
||||
k.Row(k_pos), qkv_dim);
|
||||
SingleFlashAttentionStep(x0, activations.config.att_cap,
|
||||
state.row_states[0].max, state.row_states[0].d,
|
||||
v.Row(k_pos), v.Cols(),
|
||||
att_out.Row(0) + out_offsets[0]);
|
||||
}
|
||||
if (position <= last_pos[1]) {
|
||||
// Past the last position, x1 doesn't count.
|
||||
float x1 = Dot(q.Row(0) + q_offsets[1], k.Row(k_pos), k.Cols());
|
||||
SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1,
|
||||
float x1 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[1], qkv_dim), 0,
|
||||
k.Row(k_pos), qkv_dim);
|
||||
SingleFlashAttentionStep(x1, activations.config.att_cap,
|
||||
state.row_states[1].max, state.row_states[1].d,
|
||||
v.Row(k_pos), v.Cols(),
|
||||
att_out.Row(0) + out_offsets[1]);
|
||||
}
|
||||
if (position <= last_pos[2]) {
|
||||
// Past the last position, x2 doesn't count.
|
||||
float x2 = Dot(q.Row(0) + q_offsets[2], k.Row(k_pos), k.Cols());
|
||||
SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2,
|
||||
float x2 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[2], qkv_dim), 0,
|
||||
k.Row(k_pos), qkv_dim);
|
||||
SingleFlashAttentionStep(x2, activations.config.att_cap,
|
||||
state.row_states[2].max, state.row_states[2].d,
|
||||
v.Row(k_pos), v.Cols(),
|
||||
att_out.Row(0) + out_offsets[2]);
|
||||
}
|
||||
if (position <= last_pos[3]) {
|
||||
// Past the last position, x3 doesn't count.
|
||||
float x3 = Dot(q.Row(0) + q_offsets[3], k.Row(k_pos), k.Cols());
|
||||
SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3,
|
||||
float x3 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[3], qkv_dim), 0,
|
||||
k.Row(k_pos), qkv_dim);
|
||||
SingleFlashAttentionStep(x3, activations.config.att_cap,
|
||||
state.row_states[3].max, state.row_states[3].d,
|
||||
v.Row(k_pos), v.Cols(),
|
||||
att_out.Row(0) + out_offsets[3]);
|
||||
}
|
||||
++position;
|
||||
}
|
||||
return state;
|
||||
}
|
||||
|
||||
// Rounds n to a number that can be used as the number of Q rows in a tile
|
||||
|
|
@ -589,14 +796,25 @@ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,
|
|||
// grouped together so that mode 1 or 2 can be used, and choosing which of the
|
||||
// 3 modes to use for best efficiency.
|
||||
void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
||||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||
AttentionActivations& activations, QBatch& qbatch,
|
||||
const size_t layer_idx, const MatPtr& query_norm_scale,
|
||||
AttentionActivationsPtrs& activations, QBatch& qbatch,
|
||||
ThreadingContext& ctx) {
|
||||
GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive);
|
||||
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx,
|
||||
layer, activations, ctx);
|
||||
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q,
|
||||
query_norm_scale, layer_idx, activations, ctx);
|
||||
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
// Compress q to q_bf.
|
||||
ParallelFor(
|
||||
Parallelism::kWithinCluster, activations.q.Rows(), ctx,
|
||||
/*cluster_idx=*/0, Callers::kFlashAttention,
|
||||
[&](size_t row, size_t worker) {
|
||||
CompressPerThread tls;
|
||||
const hn::ScalableTag<float> df;
|
||||
CompressTraits<BF16>::Compress(
|
||||
df, activations.q.Row(row), activations.q.Cols(), tls,
|
||||
MakeSpan(activations.q_bf.Row(row), activations.q_bf.Cols()), 0);
|
||||
});
|
||||
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
|
||||
const size_t qkv_dim = layer_config.qkv_dim;
|
||||
|
||||
// A "head group" in the context of GQA refers to a collection of query
|
||||
|
|
@ -684,14 +902,14 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
|||
size_t last = pos;
|
||||
const size_t prefix_end = qbatch.PrefixEnd(qi);
|
||||
if (prefix_end > 0 && prefix_end - 1 > last) {
|
||||
// last_pos in QDotK and WeightedSumV is inclusive.
|
||||
// last_pos in `TileFlashAttention` is inclusive.
|
||||
last = prefix_end - 1;
|
||||
}
|
||||
last_pos[offset] = last;
|
||||
min_last_pos = HWY_MIN(min_last_pos, last);
|
||||
max_last_pos = HWY_MAX(max_last_pos, last);
|
||||
q_offsets[offset] =
|
||||
activations.q.Row(tq_idx) + head * qkv_dim - activations.q.Row(0);
|
||||
q_offsets[offset] = activations.q_bf.Row(tq_idx) + head * qkv_dim -
|
||||
activations.q_bf.Row(0);
|
||||
out_offsets[offset] = activations.att_out.Row(tq_idx) + head * qkv_dim -
|
||||
activations.att_out.Row(0);
|
||||
const size_t kv_index = head / kHeadGroups;
|
||||
|
|
@ -719,8 +937,8 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
|||
// To avoid duplicating the code to setup K and V, the call to
|
||||
// TileFlashAttention is inside the loop over tasks, even though it
|
||||
// handles all rows in the task at once.
|
||||
StridedView<float> qT =
|
||||
StridedView<float>(activations.q_T.Row(0) + first_task, kVTileSize,
|
||||
StridedView<BF16> qT =
|
||||
StridedView<BF16>(activations.q_T.Row(0) + first_task, kVTileSize,
|
||||
activations.q_T.Stride());
|
||||
if (kVTileSize == kNF) {
|
||||
// We can still use TileFlashAttention even if we didn't transpose Q
|
||||
|
|
@ -730,14 +948,14 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
|||
// kNFx8HTileSize. In this case, qT is never used. Some tasks might
|
||||
// use qT and some might not, which is why the more general condition
|
||||
// is used above to catch all cases where qT will be used.
|
||||
TileFlashAttention(activations.q, q_offsets, qT, k,
|
||||
TileFlashAttention(activations.q_bf, q_offsets, qT, k,
|
||||
start_positions[offset], last_pos, min_last_pos,
|
||||
max_last_pos, v, layer_idx, layer, activations,
|
||||
max_last_pos, v, layer_idx, activations,
|
||||
activations.att_out, out_offsets, ctx, worker);
|
||||
} else if (kVTileSize == 4) {
|
||||
TileFlashAttention4(activations.q, q_offsets, k,
|
||||
TileFlashAttention4(activations.q_bf, q_offsets, k,
|
||||
start_positions[offset], last_pos, min_last_pos,
|
||||
max_last_pos, v, layer_idx, layer, activations,
|
||||
max_last_pos, v, layer_idx, activations,
|
||||
activations.att_out, out_offsets, ctx, worker);
|
||||
} else {
|
||||
HWY_UNREACHABLE;
|
||||
|
|
@ -745,8 +963,8 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
|||
break;
|
||||
} else {
|
||||
SingleFlashAttention(start_positions[offset], last_pos[offset],
|
||||
activations.q.Row(0) + q_offsets[offset], k, v,
|
||||
layer_idx, layer, activations,
|
||||
activations.q_bf.Row(0) + q_offsets[offset], k, v,
|
||||
layer_idx, activations,
|
||||
activations.att_out.Row(0) + out_offsets[offset],
|
||||
ctx, worker);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,6 +20,9 @@
|
|||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "gemma/flash_structs.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "hwy/highway.h"
|
||||
|
||||
|
|
@ -28,27 +31,36 @@ namespace gcpp {
|
|||
// Passed to HWY_VISIT_TARGETS; declares for one target.
|
||||
#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \
|
||||
namespace NAMESPACE { \
|
||||
void RMSNormAndPositionalEncoding(size_t num_tokens, const QBatch& qbatch, \
|
||||
MatPtrT<KV_t>& q, size_t layer_idx, \
|
||||
const LayerWeightsPtrs& layer, \
|
||||
const AttentionActivations& activations, \
|
||||
ThreadingContext& ctx); \
|
||||
void RMSNormAndPositionalEncoding( \
|
||||
size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
|
||||
const MatPtr& query_norm_scale, size_t layer_idx, \
|
||||
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
|
||||
\
|
||||
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
|
||||
const float* HWY_RESTRICT q, \
|
||||
const BF16* HWY_RESTRICT q, \
|
||||
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
|
||||
size_t layer_idx, const LayerWeightsPtrs& layer, \
|
||||
const AttentionActivations& activations, \
|
||||
size_t layer_idx, \
|
||||
const AttentionActivationsPtrs& activations, \
|
||||
float* HWY_RESTRICT att_out, \
|
||||
ThreadingContext& ctx, size_t worker); \
|
||||
\
|
||||
Tile4FlashState TileFlashAttention4( \
|
||||
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \
|
||||
const MatPtrT<KV_t>& k, size_t start_pos, \
|
||||
const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \
|
||||
size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx, \
|
||||
const LayerWeightsPtrs& layer, const AttentionActivations& activations, \
|
||||
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets, \
|
||||
ThreadingContext& ctx, const size_t worker); \
|
||||
\
|
||||
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
|
||||
size_t total_tasks, size_t target_parallelism); \
|
||||
\
|
||||
void FlashAttention(size_t num_tokens, size_t target_parallelism, \
|
||||
size_t layer_idx, const LayerWeightsPtrs& layer, \
|
||||
AttentionActivations& activations, QBatch& qbatch, \
|
||||
size_t layer_idx, const MatPtr& query_norm_scale, \
|
||||
AttentionActivationsPtrs& activations, QBatch& qbatch, \
|
||||
ThreadingContext& ctx); \
|
||||
\
|
||||
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
||||
} // namespace NAMESPACE
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -24,6 +26,7 @@
|
|||
#include "gemma/kv_cache.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "ops/matmul.h"
|
||||
#include "util/test_util.h"
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||
#endif // HWY_DISABLED_TARGETS
|
||||
|
|
@ -109,11 +112,14 @@ void TestFlashAttention(size_t target_parallelism) {
|
|||
const LayerConfig& layer_config = config.layer_configs[0];
|
||||
const LayerWeightsPtrs layers(0, layer_config, tensor_info_registry);
|
||||
InferenceArgs inference_args;
|
||||
inference_args.attention_impl = "flash";
|
||||
RuntimeConfig runtime_config;
|
||||
inference_args.CopyTo(runtime_config);
|
||||
KVCache kv_cache(config, inference_args, ctx.allocator);
|
||||
MatMulEnv env(ctx);
|
||||
Activations activations(config, runtime_config.prefill_tbatch_size,
|
||||
kv_cache.SeqLen(), env.ctx, env.row_ptrs);
|
||||
Activations activations(runtime_config, config,
|
||||
runtime_config.prefill_tbatch_size, kv_cache.SeqLen(),
|
||||
env.ctx, env.row_ptrs);
|
||||
std::vector<int> tokens(kOuter);
|
||||
std::iota(tokens.begin(), tokens.end(), 1);
|
||||
PromptTokens prompt(tokens);
|
||||
|
|
@ -122,8 +128,10 @@ void TestFlashAttention(size_t target_parallelism) {
|
|||
QBatch qbatch(/*start=*/0, /*max_size=*/kOuter, all_queries);
|
||||
const size_t batch_size = kOuter;
|
||||
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
|
||||
AttentionActivations attention(config, layer_config, batch_size, kOuter,
|
||||
ctx.allocator, row_ptrs);
|
||||
AttentionActivations attention_storage(config, layer_config, batch_size,
|
||||
kOuter, runtime_config, ctx.allocator,
|
||||
row_ptrs);
|
||||
AttentionActivationsPtrs attention(config, kOuter, attention_storage);
|
||||
const size_t qkv_dim = layer_config.qkv_dim;
|
||||
ASSERT_EQ(qkv_dim, kInner);
|
||||
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||
|
|
@ -145,7 +153,8 @@ void TestFlashAttention(size_t target_parallelism) {
|
|||
SetMat(h + layer_config.heads * 2, v);
|
||||
}
|
||||
SetMat(1, attention.q);
|
||||
DotSoftmaxWeightedSum(tokens.size(), 0, layers, attention, qbatch, ctx);
|
||||
DotSoftmaxWeightedSum(tokens.size(), 0, layers.query_norm_scale, attention,
|
||||
qbatch, ctx);
|
||||
// Copy the output to saved_att to allow for comparison.
|
||||
auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator);
|
||||
SetMat(1, attention.q);
|
||||
|
|
@ -158,8 +167,8 @@ void TestFlashAttention(size_t target_parallelism) {
|
|||
total_tasks, target_parallelism);
|
||||
printf("FlashAttention: target_parallelism=%zu, kNF=%zu, kVTileSize=%zu\n",
|
||||
target_parallelism, kNF, kVTileSize);
|
||||
FlashAttention(tokens.size(), target_parallelism, 0, layers, attention,
|
||||
qbatch, ctx);
|
||||
FlashAttention(tokens.size(), target_parallelism, 0, layers.query_norm_scale,
|
||||
attention, qbatch, ctx);
|
||||
AssertClose(attention.att_out, *saved_att);
|
||||
ctx.profiler.PrintResults();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,31 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <limits>
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// State for computing softmax in a streaming ("online") manner,
|
||||
// avoiding large intermediate values by subtracting the running maximum.
|
||||
// For a sequence x_1, ..., x_n:
|
||||
// m_i = max(m_{i-1}, x_i)
|
||||
// d_i = d_{i-1} * exp(m_{i-1} - m_i) + exp(x_i - m_i)
|
||||
// softmax_i = exp(x_i - m_i) / d_i
|
||||
struct OnlineSoftmaxState {
|
||||
// Maximum logit value encountered so far.
|
||||
float max = -std::numeric_limits<float>::max() / 2.0f;
|
||||
// Sum of exponentials scaled by exp(-max).
|
||||
float d = 0.0f;
|
||||
};
|
||||
|
||||
static constexpr size_t kVTileSize4 = 4;
|
||||
|
||||
struct Tile4FlashState {
|
||||
OnlineSoftmaxState row_states[kVTileSize4];
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_
|
||||
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/tensor_stats.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "ops/matmul.h"
|
||||
#include "util/mat.h"
|
||||
|
|
@ -70,7 +71,7 @@ template <class Mat>
|
|||
void ActivationBatched(
|
||||
ActivationType activation, Mat& c1, ThreadingContext& ctx,
|
||||
size_t cluster_idx = 0,
|
||||
ParallelismStrategy parallelism = ParallelismStrategy::kFlat) {
|
||||
Parallelism parallelism = Parallelism::kFlat) {
|
||||
using T = typename Mat::T;
|
||||
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
|
||||
Callers::kActivationBatched, [&](uint64_t task, size_t worker) {
|
||||
|
|
@ -115,7 +116,7 @@ template <class Mat1, class Mat2>
|
|||
HWY_NOINLINE void ActivationBatched(
|
||||
ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx,
|
||||
size_t cluster_idx = 0,
|
||||
ParallelismStrategy parallelism = ParallelismStrategy::kFlat) {
|
||||
Parallelism parallelism = Parallelism::kFlat) {
|
||||
HWY_DASSERT(c1.SameShape(*c2));
|
||||
if (c2 && c2->HasPtr()) {
|
||||
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
|
||||
|
|
@ -158,6 +159,9 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
|
|||
|
||||
HWY_DASSERT(!layer_config.ff_biases); // Only used in Vit.
|
||||
|
||||
activations.s_ffw_in.Notify(layer.layer_idx, activations.pre_ffw_rms_out,
|
||||
env.ctx);
|
||||
|
||||
#if GEMMA_FUSED_FFN
|
||||
const auto fused = [&](RowPtrsBF C1, IndexRange range_r, IndexRange range_c,
|
||||
StridedViewBF C2, size_t worker) {
|
||||
|
|
@ -179,8 +183,31 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
|
|||
env.ctx);
|
||||
#endif
|
||||
|
||||
activations.s_ffw_hidden.Notify(layer.layer_idx, activations.C1, env.ctx);
|
||||
|
||||
// Hidden layer -> output layer.
|
||||
CallMatMul(activations.C1, layer.linear_w, nullptr, env, activations.ffw_out);
|
||||
|
||||
activations.s_ffw_out.Notify(layer.layer_idx, activations.ffw_out, env.ctx);
|
||||
}
|
||||
|
||||
// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
|
||||
// head_dim (`qkv_dim`) into output (`layer_out`).
|
||||
static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
|
||||
AttentionActivationsPtrs& activations,
|
||||
MatMulEnv& env) {
|
||||
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
(void)layer_config; // For HWY_DASSERT
|
||||
// att_weights and att_out are concatenated heads, each of length
|
||||
// layer_config.qkv_dim. Thus the [num_interleaved,
|
||||
// layer_config.model_dim] matmul output is the sum over heads. Compare
|
||||
// gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD',
|
||||
// encoded)
|
||||
HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 &&
|
||||
layer_config.qkv_dim != 0);
|
||||
CallMatMul(activations.att_out, layer.att_weights, /*add=*/nullptr, env,
|
||||
activations.att_sums);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
|
|
|
|||
233
gemma/gemma.cc
233
gemma/gemma.cc
|
|
@ -18,12 +18,16 @@
|
|||
|
||||
#include "gemma/gemma.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||
#include "util/zones.h"
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||
#endif // HWY_DISABLED_TARGETS
|
||||
|
||||
#include "gemma/tensor_stats.h"
|
||||
#include "util/zones.h"
|
||||
|
||||
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||
// which we pass the filename via macro 'argument'.
|
||||
// clang-format off
|
||||
|
|
@ -73,10 +77,12 @@ namespace HWY_NAMESPACE {
|
|||
void Attention(LayerAttentionType type, const size_t num_tokens,
|
||||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||
Activations& activations, QBatch& qbatch, MatMulEnv& env) {
|
||||
|
||||
if (type == LayerAttentionType::kGemma) {
|
||||
// TODO: remove flag to enable FlashAttention.
|
||||
GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
|
||||
env, HWY_NATIVE_DOT_BF16 ? kAttentionUseOld : 0);
|
||||
GemmaAttention(
|
||||
num_tokens, layer_idx, layer, activations.attention, qbatch, env,
|
||||
AttentionImplToFlags(activations.attention_impl, HWY_NATIVE_DOT_BF16));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -355,6 +361,10 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
|
|||
(void)runtime_config.StreamToken(qbatch.QueryIdx(qi), pos_in_prompt,
|
||||
token, 0.0f);
|
||||
qbatch.MutablePos(qi) = pos_in_prompt;
|
||||
} else {
|
||||
// This prevents the kv cache of eos_id to be written to last prefilled
|
||||
// token.
|
||||
qbatch.MutablePos(qi) = qbatch.Prompt(qi).size();
|
||||
}
|
||||
|
||||
qbatch.PrevToken(qi) = token;
|
||||
|
|
@ -426,7 +436,7 @@ static void SampleAndStream(const ModelConfig& config,
|
|||
timing_info.NotifyGenerated(non_eos.Count());
|
||||
|
||||
ParallelFor(
|
||||
ParallelismStrategy::kFlat, qbatch.Size(), env.ctx,
|
||||
Parallelism::kFlat, qbatch.Size(), env.ctx,
|
||||
/*cluster_idx=*/0, Callers::kSampleAndStream,
|
||||
[&](size_t qi, size_t worker) {
|
||||
if (!non_eos.Get(qi)) return;
|
||||
|
|
@ -484,12 +494,11 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config,
|
|||
};
|
||||
}
|
||||
|
||||
// Decode: generates one continuation token for each query in `qbatch`.
|
||||
static void GenerateT(const ModelConfig& config,
|
||||
static size_t PrefillTBatchOrQBatch(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const AesCtrEngine& engine, const WeightsPtrs& weights,
|
||||
Activations& activations, QBatch& qbatch, MatMulEnv& env,
|
||||
TimingInfo& timing_info) {
|
||||
const WeightsPtrs& weights,
|
||||
Activations& activations, QBatch& qbatch,
|
||||
MatMulEnv& env, TimingInfo& timing_info) {
|
||||
size_t max_prompt_size = 0;
|
||||
bool all_prefix_end_are_zero = true;
|
||||
size_t total_prefill_tokens = 0; // only for throughput stats.
|
||||
|
|
@ -511,14 +520,14 @@ static void GenerateT(const ModelConfig& config,
|
|||
// We use a single divisor, so all sequence lengths must be the same.
|
||||
HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
|
||||
}
|
||||
if (max_prompt_size >= seq_len) {
|
||||
HWY_ABORT("max_prompt_size = %zu, increase --seq_len to at least that.",
|
||||
max_prompt_size);
|
||||
if (max_prompt_size > seq_len) {
|
||||
HWY_ABORT(
|
||||
"max_prompt_size = %zu, seq_len = %zu, increase --seq_len to at least "
|
||||
"that.",
|
||||
max_prompt_size, seq_len);
|
||||
}
|
||||
HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);
|
||||
|
||||
// Lacks a constructor to bulk-set, hence initialized by Prefill* which have
|
||||
// qi loops anyway.
|
||||
hwy::BitSet4096<> non_eos; // indexed by qi
|
||||
|
||||
timing_info.prefill_start = hwy::platform::Now();
|
||||
|
|
@ -536,8 +545,21 @@ static void GenerateT(const ModelConfig& config,
|
|||
timing_info.NotifyPrefill(total_prefill_tokens);
|
||||
// queries_pos have been incremented by Prefill.
|
||||
|
||||
// Stream the last prompt token from each query, fill activations.gen_tokens.
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
size_t max_gen_steps = runtime_config.max_generated_tokens;
|
||||
if (max_prompt_size + max_gen_steps > seq_len) {
|
||||
HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.",
|
||||
max_prompt_size, max_gen_steps, seq_len);
|
||||
max_gen_steps = seq_len - max_prompt_size;
|
||||
}
|
||||
|
||||
return max_gen_steps;
|
||||
}
|
||||
|
||||
static void StreamAndUpdateEOSAfterPrefill(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
QBatch& qbatch,
|
||||
hwy::BitSet4096<>& non_eos,
|
||||
size_t qi) {
|
||||
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
|
||||
|
||||
const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct.
|
||||
|
|
@ -548,11 +570,35 @@ static void GenerateT(const ModelConfig& config,
|
|||
config, runtime_config, qbatch, update_pos, non_eos);
|
||||
}
|
||||
|
||||
size_t max_gen_steps = runtime_config.max_generated_tokens;
|
||||
if (max_prompt_size + max_gen_steps > seq_len) {
|
||||
HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.",
|
||||
max_prompt_size, max_gen_steps, seq_len);
|
||||
max_gen_steps = seq_len - max_prompt_size;
|
||||
void SetWeightStats(const LayerWeightsPtrs& layer, Activations& a,
|
||||
ThreadingContext& ctx) {
|
||||
const size_t layer_idx = layer.layer_idx;
|
||||
a.s_w_gating_einsum_w1.Notify(layer_idx, layer.gating_einsum_w1, ctx,
|
||||
kTensorStatsIsWeight);
|
||||
a.s_w_gating_einsum_w2.Notify(layer_idx, layer.gating_einsum_w2, ctx,
|
||||
kTensorStatsIsWeight);
|
||||
a.s_w_linear_w.Notify(layer_idx, layer.linear_w, ctx, kTensorStatsIsWeight);
|
||||
}
|
||||
|
||||
// Decode: generates one continuation token for each query in `qbatch`.
|
||||
static void GenerateT(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const AesCtrEngine& engine, const WeightsPtrs& weights,
|
||||
Activations& activations, QBatch& qbatch, MatMulEnv& env,
|
||||
TimingInfo& timing_info) {
|
||||
for (const LayerWeightsPtrs& layer : weights.c_layers) {
|
||||
SetWeightStats(layer, activations, env.ctx);
|
||||
}
|
||||
|
||||
const size_t max_gen_steps = PrefillTBatchOrQBatch(
|
||||
config, runtime_config, weights, activations, qbatch, env, timing_info);
|
||||
|
||||
hwy::BitSet4096<> non_eos; // indexed by qi
|
||||
|
||||
// Stream the last prompt token from each query, fill activations.gen_tokens.
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
non_eos.Set(qi);
|
||||
StreamAndUpdateEOSAfterPrefill(config, runtime_config, qbatch, non_eos, qi);
|
||||
}
|
||||
|
||||
const SampleFunc sample_token =
|
||||
|
|
@ -567,14 +613,66 @@ static void GenerateT(const ModelConfig& config,
|
|||
timing_info.NotifyGenerateDone();
|
||||
}
|
||||
|
||||
// Same as GenerateT, but uses ContinuousQBatch.
|
||||
static void GenerateTWithContinuousBatching(
|
||||
const ModelConfig& config, const RuntimeConfig& runtime_config,
|
||||
const AesCtrEngine& engine, const WeightsPtrs& weights,
|
||||
Activations& activations, AllQueries& all_queries, MatMulEnv& env,
|
||||
TimingInfo& timing_info) {
|
||||
const size_t qbatch_size = runtime_config.decode_qbatch_size;
|
||||
|
||||
QBatch qbatch(0, qbatch_size, all_queries);
|
||||
ContinuousQBatch prefill_batch(qbatch_size, all_queries);
|
||||
|
||||
hwy::BitSet4096<> non_eos;
|
||||
const SampleFunc sample_token =
|
||||
ChooseSampleFunc(runtime_config, engine, env.ctx);
|
||||
|
||||
size_t query_inserted = 0;
|
||||
while (non_eos.Any() || query_inserted < all_queries.NumQueries()) {
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
// Continue if qi slot is still processing.
|
||||
if (non_eos.Get(qi)) continue;
|
||||
// Collect the kv_cache from the qi slot in the qbatch to the
|
||||
// available_kv_caches_ in the prefill_batch.
|
||||
prefill_batch.MaybeReleaseKV(qbatch.Single(qi));
|
||||
|
||||
// Prefill if no available prefilled queries to insert.
|
||||
if (prefill_batch.ShouldPrefill()) {
|
||||
prefill_batch.SetupNextBatchForPrefill();
|
||||
PrefillTBatchOrQBatch(config, runtime_config, weights, activations,
|
||||
prefill_batch, env, timing_info);
|
||||
activations.SetBatchSize(qbatch.Size());
|
||||
}
|
||||
|
||||
// Get the next query to insert to the generate batch.
|
||||
std::optional<size_t> qi_to_insert = prefill_batch.GetNextToInsert();
|
||||
if (qi_to_insert) {
|
||||
qbatch.Insert(qi_to_insert.value(), qi);
|
||||
query_inserted++;
|
||||
|
||||
non_eos.Set(qi);
|
||||
StreamAndUpdateEOSAfterPrefill(config, runtime_config, qbatch, non_eos,
|
||||
qi);
|
||||
}
|
||||
}
|
||||
|
||||
Transformer(config, runtime_config, weights, activations, qbatch, env);
|
||||
SampleAndStream(config, runtime_config, weights, sample_token, activations,
|
||||
qbatch, env, non_eos, timing_info);
|
||||
}
|
||||
timing_info.NotifyGenerateDone();
|
||||
}
|
||||
|
||||
void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||
const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const AesCtrEngine& engine, const WeightsPtrs& weights,
|
||||
KVCache& kv_cache, MatMulEnv& env,
|
||||
TimingInfo& timing_info) {
|
||||
Activations activations(config, runtime_config.prefill_tbatch_size,
|
||||
kv_cache.SeqLen(), env.ctx, env.row_ptrs);
|
||||
Activations activations(runtime_config, config,
|
||||
runtime_config.prefill_tbatch_size, kv_cache.SeqLen(),
|
||||
env.ctx, env.row_ptrs);
|
||||
|
||||
AllQueries all_queries(prompt, pos, prefix_end,
|
||||
hwy::Span<KVCache>(&kv_cache, 1));
|
||||
|
|
@ -592,16 +690,21 @@ void GenerateBatchT(const ModelConfig& config,
|
|||
TimingInfo& timing_info) {
|
||||
const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
|
||||
runtime_config.prefill_tbatch_size);
|
||||
Activations activations(config, max_batch_size,
|
||||
Activations activations(runtime_config, config, max_batch_size,
|
||||
all_queries[0].kv_cache.SeqLen(), env.ctx,
|
||||
env.row_ptrs);
|
||||
|
||||
if (runtime_config.use_continuous_batching) {
|
||||
GenerateTWithContinuousBatching(config, runtime_config, engine, weights,
|
||||
activations, all_queries, env, timing_info);
|
||||
} else {
|
||||
for (size_t start = 0; start < all_queries.NumQueries();
|
||||
start += runtime_config.decode_qbatch_size) {
|
||||
QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries);
|
||||
// Generate a batch of one token for each of `qbatch.Size()` queries.
|
||||
GenerateT(config, runtime_config, engine, weights, activations, qbatch, env,
|
||||
timing_info);
|
||||
GenerateT(config, runtime_config, engine, weights, activations, qbatch,
|
||||
env, timing_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -617,8 +720,8 @@ void GenerateImageTokensT(const ModelConfig& config,
|
|||
const size_t num_tokens = vit_config.max_seq_len;
|
||||
prefill_runtime_config.prefill_tbatch_size =
|
||||
num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
|
||||
Activations prefill_activations(vit_config, num_tokens, num_tokens, env.ctx,
|
||||
env.row_ptrs);
|
||||
Activations prefill_activations(runtime_config, vit_config, num_tokens,
|
||||
num_tokens, env.ctx, env.row_ptrs);
|
||||
// Weights are for the full PaliGemma model, not just the ViT part.
|
||||
PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
|
||||
prefill_activations, env);
|
||||
|
|
@ -635,17 +738,16 @@ HWY_EXPORT(GenerateSingleT);
|
|||
HWY_EXPORT(GenerateBatchT);
|
||||
HWY_EXPORT(GenerateImageTokensT);
|
||||
|
||||
Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||
ThreadingContext& ctx)
|
||||
: reader_(loader.weights),
|
||||
model_(reader_, loader.tokenizer, loader.wrapping),
|
||||
Gemma::Gemma(const GemmaArgs& args, ThreadingContext& ctx)
|
||||
: reader_(args.loader.weights),
|
||||
model_(reader_, args.loader.tokenizer, args.loader.wrapping),
|
||||
weights_(model_.Config()),
|
||||
chat_template_(model_.Tokenizer(), model_.Config().model),
|
||||
inference_(inference),
|
||||
aes_ctr_engine_(inference.deterministic) {
|
||||
inference_(args.inference),
|
||||
aes_ctr_engine_(args.inference.deterministic) {
|
||||
// Negligible CPU time in the ctor body (except ReadFromBlobs).
|
||||
weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader, inference,
|
||||
mat_owners_, ctx);
|
||||
weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, args.loader,
|
||||
args.inference, mat_owners_, ctx);
|
||||
// Read everything into memory, or `weights_.mapped_` keeps the mapping alive.
|
||||
reader_.CloseFile();
|
||||
}
|
||||
|
|
@ -698,5 +800,64 @@ void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
|
|||
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
||||
}
|
||||
|
||||
ContinuousQBatch::ContinuousQBatch(size_t max_size, AllQueries& queries)
|
||||
: QBatch(0, max_size, queries) {
|
||||
for (size_t i = start_; i < queries_.NumQueries(); ++i) {
|
||||
if (!queries_[i].kv_cache.IsEmpty()) {
|
||||
// Put the kv_cache to the available_kv_caches_ instead; leaving the
|
||||
// kv_cache in the queries_ is very confusing. This simplifies the logic
|
||||
// of kv_cache management.
|
||||
available_kv_caches_.push_back(queries_[i].kv_cache);
|
||||
queries_[i].kv_cache = KVCachePtr();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool ContinuousQBatch::ShouldPrefill() const {
|
||||
const bool no_available_to_insert = next_to_insert_ == next_to_prefill_;
|
||||
const int more_queries_to_prefill = next_to_prefill_ < queries_.NumQueries();
|
||||
return no_available_to_insert && more_queries_to_prefill;
|
||||
}
|
||||
|
||||
void ContinuousQBatch::SetupNextBatchForPrefill() {
|
||||
start_ = next_to_prefill_;
|
||||
size_ = HWY_MIN(max_size_, queries_.NumQueries() - start_);
|
||||
HWY_DASSERT(size_ != 0);
|
||||
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
|
||||
query_idx_.clear();
|
||||
query_idx_.reserve(size_);
|
||||
for (size_t i = 0; i < size_; ++i) {
|
||||
const size_t next_query_idx = start_ + i;
|
||||
query_idx_.push_back(next_query_idx);
|
||||
HWY_ASSERT(queries_[next_query_idx].kv_cache.IsEmpty());
|
||||
queries_[next_query_idx].kv_cache = available_kv_caches_.back();
|
||||
available_kv_caches_.pop_back();
|
||||
}
|
||||
next_to_prefill_ += size_;
|
||||
}
|
||||
|
||||
std::optional<size_t> ContinuousQBatch::GetNextToInsert() {
|
||||
if (next_to_insert_ == next_to_prefill_) {
|
||||
return std::nullopt;
|
||||
}
|
||||
next_to_insert_++;
|
||||
return next_to_insert_ - 1;
|
||||
}
|
||||
|
||||
void ContinuousQBatch::MaybeReleaseKV(const QBatch& from) {
|
||||
const int query_to_collect = from.QueryIdx(0);
|
||||
// Only collect if the query to collect is not the same as the next query to
|
||||
// insert. This happens at the beginning of each Generate call.
|
||||
if (query_to_collect != next_to_insert_) {
|
||||
// Only clear the KV cache if there are more queries to insert; Otherwise
|
||||
// we get a crash because Transformer will still access that KV cache.
|
||||
if (next_to_insert_ < queries_.NumQueries()) {
|
||||
available_kv_caches_.push_back(from.KV(0));
|
||||
ZeroInit(from.KV(0).kv_cache);
|
||||
from.KV(0) = KVCachePtr();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // HWY_ONCE
|
||||
|
|
|
|||
147
gemma/gemma.h
147
gemma/gemma.h
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
|
|
@ -26,6 +27,7 @@
|
|||
#include "gemma/gemma_args.h"
|
||||
#include "gemma/kv_cache.h"
|
||||
#include "gemma/model_store.h"
|
||||
#include "gemma/query.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "io/blob_store.h"
|
||||
#include "io/io.h" // Path
|
||||
|
|
@ -38,132 +40,28 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
struct PerQuery {
|
||||
PromptTokens prompt;
|
||||
|
||||
// Position in the KV cache: initially zero for the first turn, or when
|
||||
// multi-turn is NOT desired. Incremented by prefill and `StreamAndUpdateEOS`.
|
||||
size_t mutable_pos;
|
||||
// Allows computing the last prefill token as `mutable_pos - initial_pos`,
|
||||
// which might differ from `prompt.size() - 1` for prefix-LM.
|
||||
size_t initial_pos;
|
||||
// Zero for causal attention, or the end of the prefix for prefix-LM style
|
||||
// attention in Paligemma.
|
||||
size_t prefix_end;
|
||||
|
||||
KVCache& kv_cache;
|
||||
|
||||
// Previous token generated for this query, or the last prompt token. Will be
|
||||
// fed into the next Transformer() call.
|
||||
int prev_token = 0;
|
||||
};
|
||||
|
||||
// Array of `PerQuery`. Referenced by `QBatch` and passed to `GenerateBatch`.
|
||||
struct AllQueries {
|
||||
AllQueries() = default;
|
||||
|
||||
// For `GenerateSingleT`: same prompt/pos, replicated for each KV cache.
|
||||
AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||
const hwy::Span<KVCache>& kv_caches) {
|
||||
per_query_.reserve(kv_caches.size());
|
||||
for (size_t i = 0; i < kv_caches.size(); ++i) {
|
||||
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
|
||||
per_query_.push_back(PerQuery{
|
||||
.prompt = prompt,
|
||||
.mutable_pos = pos,
|
||||
.initial_pos = pos,
|
||||
.prefix_end = prefix_end,
|
||||
.kv_cache = kv_caches[i],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Batch of queries with initial position set to zero. Causal attention
|
||||
// is requested via empty or all-zero `prefix_end`.
|
||||
AllQueries(
|
||||
const hwy::Span<const PromptTokens>& prompts,
|
||||
const hwy::Span<KVCache>& kv_caches,
|
||||
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
|
||||
HWY_ASSERT(prompts.size() == kv_caches.size());
|
||||
HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
|
||||
per_query_.reserve(kv_caches.size());
|
||||
for (size_t i = 0; i < kv_caches.size(); ++i) {
|
||||
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
|
||||
per_query_.push_back(PerQuery{
|
||||
.prompt = prompts[i],
|
||||
.mutable_pos = 0,
|
||||
.initial_pos = 0,
|
||||
.prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i],
|
||||
.kv_cache = kv_caches[i],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void Reserve(size_t size) { per_query_.reserve(size); }
|
||||
void Append(const PerQuery& query) { per_query_.push_back(query); }
|
||||
|
||||
size_t NumQueries() const { return per_query_.size(); }
|
||||
|
||||
PerQuery& operator[](size_t query_idx) {
|
||||
HWY_DASSERT(query_idx < NumQueries());
|
||||
return per_query_[query_idx];
|
||||
}
|
||||
const PerQuery& operator[](size_t query_idx) const {
|
||||
HWY_DASSERT(query_idx < NumQueries());
|
||||
return per_query_[query_idx];
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<PerQuery> per_query_;
|
||||
};
|
||||
|
||||
// View into AllQueries: either a batch of queries, or a single query for use
|
||||
// in PrefillTBatch or GenerateSingleT. Cheap to create because it holds a
|
||||
// reference to AllQueries.
|
||||
class QBatch {
|
||||
// Used for continuous batching.
|
||||
class ContinuousQBatch : public QBatch {
|
||||
public:
|
||||
QBatch(size_t start, size_t max_size, AllQueries& queries)
|
||||
: start_(start),
|
||||
max_size_(max_size),
|
||||
queries_(queries),
|
||||
size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) {
|
||||
HWY_ASSERT(max_size_ <= kMaxBatchSize);
|
||||
HWY_DASSERT(size_ != 0);
|
||||
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
|
||||
}
|
||||
ContinuousQBatch(size_t max_size, AllQueries& queries);
|
||||
|
||||
// Returns a single-query view starting at `qi` relative to this batch.
|
||||
QBatch Single(size_t qi) const { return QBatch(start_ + qi, 1, queries_); }
|
||||
// Whether we should prefill the next batch, i.e. next_to_insert_ ==
|
||||
// next_to_prefill_.
|
||||
bool ShouldPrefill() const;
|
||||
|
||||
// How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`.
|
||||
size_t Size() const { return size_; }
|
||||
// Setup the query_idx_ to point to the next group of queries to prefill.
|
||||
void SetupNextBatchForPrefill();
|
||||
|
||||
// Returns index for use with `AllQueries` and `BatchStreamToken`.
|
||||
size_t QueryIdx(size_t qi) const {
|
||||
HWY_DASSERT(qi < size_);
|
||||
return start_ + qi;
|
||||
}
|
||||
// Get the next query to insert to the generate batch.
|
||||
std::optional<size_t> GetNextToInsert();
|
||||
|
||||
// Accessor functions to bridge the previous SoA and current AoS layout.
|
||||
const PromptTokens& Prompt(size_t qi) const {
|
||||
return queries_[QueryIdx(qi)].prompt;
|
||||
}
|
||||
size_t Pos(size_t qi) const { return queries_[QueryIdx(qi)].mutable_pos; }
|
||||
size_t& MutablePos(size_t qi) { return queries_[QueryIdx(qi)].mutable_pos; }
|
||||
size_t InitialPos(size_t qi) const {
|
||||
return queries_[QueryIdx(qi)].initial_pos;
|
||||
}
|
||||
size_t PrefixEnd(size_t qi) const {
|
||||
return queries_[QueryIdx(qi)].prefix_end;
|
||||
}
|
||||
KVCache& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
|
||||
int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
|
||||
// Collect the kv_cache from QBatch to available_kv_caches_.
|
||||
void MaybeReleaseKV(const QBatch& from);
|
||||
|
||||
private:
|
||||
size_t start_;
|
||||
size_t max_size_;
|
||||
AllQueries& queries_;
|
||||
size_t size_;
|
||||
public:
|
||||
int next_to_prefill_ = 0;
|
||||
int next_to_insert_ = 0;
|
||||
std::vector<KVCachePtr> available_kv_caches_;
|
||||
};
|
||||
|
||||
struct TimingInfo {
|
||||
|
|
@ -232,11 +130,16 @@ struct TimingInfo {
|
|||
// separate `ThreadingContext` and `MatMulEnv` for each concurrent `Generate`.
|
||||
class Gemma {
|
||||
public:
|
||||
// Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
|
||||
// Reads weights/config/tokenizer from `BlobStore` at `args.loader.weights`.
|
||||
// `ctx` is only used to read tensors and not stored. Calls to `Generate*`
|
||||
// may reference the same, or other `ThreadingContext` via `MatMulEnv`.
|
||||
Gemma(const GemmaArgs& args, ThreadingContext& ctx);
|
||||
|
||||
// Deprecated prior interface for backwards compatibility.
|
||||
Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||
ThreadingContext& ctx);
|
||||
ThreadingContext& ctx)
|
||||
: Gemma(GemmaArgs(loader, ThreadingArgs(), inference), ctx) {}
|
||||
|
||||
~Gemma();
|
||||
|
||||
const ModelConfig& Config() const { return model_.Config(); }
|
||||
|
|
|
|||
|
|
@ -24,10 +24,12 @@
|
|||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
#include "gemma/configs.h"
|
||||
#include "io/io.h" // Path
|
||||
#include "util/args.h"
|
||||
#include "util/args.h" // IWYU pragma: export
|
||||
#include "util/basics.h" // Tristate
|
||||
#include "util/mat.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/aligned_allocator.h" // Span
|
||||
#include "hwy/base.h" // HWY_ABORT
|
||||
#include "hwy/profiler.h"
|
||||
|
|
@ -35,7 +37,9 @@
|
|||
namespace gcpp {
|
||||
|
||||
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
LoaderArgs(int argc, char* argv[], ConsumedArgs& consumed) {
|
||||
InitAndParse(argc, argv, consumed);
|
||||
}
|
||||
LoaderArgs(const std::string& tokenizer_path,
|
||||
const std::string& weights_path) {
|
||||
Init(); // Init sets to defaults, so assignments must come after Init().
|
||||
|
|
@ -139,6 +143,9 @@ struct RuntimeConfig {
|
|||
|
||||
int verbosity; // Controls verbosity of printed messages.
|
||||
|
||||
// Which attention implementation to use.
|
||||
AttentionImpl attention_impl = AttentionImpl::kFlash;
|
||||
|
||||
// Functions operating on the generated tokens.
|
||||
StreamFunc stream_token;
|
||||
BatchStreamFunc batch_stream_token;
|
||||
|
|
@ -159,10 +166,15 @@ struct RuntimeConfig {
|
|||
// default decision is likely sufficient because it is based on whether
|
||||
// threads are successfully pinned.
|
||||
mutable Tristate use_spinning = Tristate::kDefault;
|
||||
|
||||
// Whether to use continuous batching.
|
||||
bool use_continuous_batching = false;
|
||||
};
|
||||
|
||||
struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
InferenceArgs(int argc, char* argv[], ConsumedArgs& consumed) {
|
||||
InitAndParse(argc, argv, consumed);
|
||||
}
|
||||
InferenceArgs() { Init(); };
|
||||
|
||||
bool IsInteractive() const { return prompt.empty() && prompt_file.Empty(); }
|
||||
|
|
@ -187,6 +199,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
// For prompts longer than the Linux terminal's 4K line edit buffer.
|
||||
Path prompt_file;
|
||||
std::string eot_line;
|
||||
std::string attention_impl;
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
|
|
@ -240,6 +253,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
"before the line where only the given string appears.\n Default = "
|
||||
"When a newline is encountered, that signals the end of the turn.",
|
||||
2);
|
||||
visitor(attention_impl, "attention_impl", std::string("flash"),
|
||||
"Attention implementation to use. See configs.cc for options.", 2);
|
||||
}
|
||||
|
||||
void CopyTo(RuntimeConfig& runtime_config) const {
|
||||
|
|
@ -261,36 +276,39 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
|
||||
runtime_config.temperature = temperature;
|
||||
runtime_config.top_k = top_k;
|
||||
runtime_config.attention_impl = GetAttentionImpl(attention_impl);
|
||||
}
|
||||
};
|
||||
|
||||
struct ClientArgs : public ArgsBase<ClientArgs> {
|
||||
ClientArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
ClientArgs() { Init(); };
|
||||
// Bundles all args required to construct a `GemmaEnv` or the equivalent.
|
||||
struct GemmaArgs {
|
||||
// For callers that do not parse command line args.
|
||||
GemmaArgs(const LoaderArgs& loader,
|
||||
const ThreadingArgs& threading = ThreadingArgs(),
|
||||
const InferenceArgs& inference = InferenceArgs())
|
||||
: loader(loader), threading(threading), inference(inference) {}
|
||||
|
||||
std::string host;
|
||||
int port;
|
||||
std::string api_key;
|
||||
std::string model;
|
||||
std::string prompt;
|
||||
bool interactive;
|
||||
GemmaArgs(int argc, char** argv, ConsumedArgs& consumed)
|
||||
: loader(argc, argv, consumed),
|
||||
threading(argc, argv, consumed),
|
||||
inference(argc, argv, consumed) {}
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(host, "host", std::string("localhost"),
|
||||
"Server host (default: localhost)");
|
||||
visitor(port, "port", 8080,
|
||||
"Server port (default: 8080)");
|
||||
visitor(api_key, "api_key", std::string(""),
|
||||
"Use public API with key (changes host to "
|
||||
"generativelanguage.googleapis.com:443)");
|
||||
visitor(model, "model", std::string("gemma3-4b"),
|
||||
"Model name to use (default: gemma3-4b)");
|
||||
visitor(prompt, "prompt", std::string("Hello! How are you?"),
|
||||
"Prompt for generation (default: 'Hello! How are you?')");
|
||||
visitor(interactive, "interactive", false,
|
||||
"Start interactive chat mode (0 = no, 1 = yes)");
|
||||
void Help() {
|
||||
fprintf(stderr,
|
||||
"To run with pre-2025 weights, specify --tokenizer and --weights.\n"
|
||||
"With the single-file weights format, specify just --weights.\n"
|
||||
"\n*Model Loading Arguments*\n");
|
||||
loader.Help();
|
||||
fprintf(stderr, "\n*Threading Arguments*\n");
|
||||
threading.Help();
|
||||
fprintf(stderr, "\n*Inference Arguments*\n");
|
||||
inference.Help();
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
LoaderArgs loader;
|
||||
ThreadingArgs threading;
|
||||
InferenceArgs inference;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||

gemma/gemma_args_test.cc
@@ -0,0 +1,74 @@
#include "gemma/gemma_args.h"

#include <stddef.h>

#include <string>
#include <vector>

#include "gtest/gtest.h"

namespace gcpp {

void FillPtrs(const std::vector<std::string>& args, std::vector<char*>& ptrs) {
  ptrs.reserve(args.size());
  for (const std::string& arg : args) {
    ptrs.push_back(const_cast<char*>(arg.data()));
  }
}

static void CheckAllConsumed(const std::vector<std::string>& args) {
  std::vector<char*> ptrs;
  FillPtrs(args, ptrs);
  const int argc = static_cast<int>(args.size());
  char** argv = const_cast<char**>(ptrs.data());

  ConsumedArgs consumed(argc, argv);
  GemmaArgs gemma_args(argc, argv, consumed);
  consumed.AbortIfUnconsumed();
}

static void CheckUnconsumed(const std::vector<std::string>& args,
                            size_t expected) {
  std::vector<char*> ptrs;
  FillPtrs(args, ptrs);
  const int argc = static_cast<int>(args.size());
  char** argv = const_cast<char**>(ptrs.data());

  ConsumedArgs consumed(argc, argv);
  GemmaArgs gemma_args(argc, argv, consumed);
  ASSERT_EQ(expected, consumed.FirstUnconsumed());
}

// Note: do not use --help here because it is not consumed; it is
// special-cased in `HasHelp`.
TEST(GemmaArgsTest, AllConsumedArgs) {
  // Single arg
  CheckAllConsumed({"gemma", "--weights=x"});
  // Two args, one with =
  CheckAllConsumed({"gemma", "--weights=x", "--verbosity=1"});
  // Two args, one with extra value
  CheckAllConsumed({"gemma", "--weights=x", "--verbosity", "2"});
  // Two args with values
  CheckAllConsumed({"gemma", "--verbosity", "2", "--deterministic=true"});
}

TEST(GemmaArgsTest, UnconsumedArgs) {
  // Single unconsumed arg
  CheckUnconsumed({"gemma", "--UNDEFINED"}, 1);
  // Single unconsumed arg, no --
  CheckUnconsumed({"gemma", "UNDEFINED"}, 1);
  // Single unconsumed arg after valid arg
  CheckUnconsumed({"gemma", "--weights=x", "--UNDEFINED"}, 2);
  // Single unconsumed arg before valid arg
  CheckUnconsumed({"gemma", "--UNDEFINED", "--weights=x"}, 1);
  // Single unconsumed arg with = after valid arg
  CheckUnconsumed({"gemma", "--weights=x", "--UNDEFINED=1"}, 2);
  // Single unconsumed arg with = before valid arg
  CheckUnconsumed({"gemma", "--UNDEFINED=false", "--weights=x"}, 1);
  // Multiple unconsumed args
  CheckUnconsumed({"gemma", "--UNDEFINED", "--XXX"}, 1);
  // Multiple unconsumed args with valid arg between
  CheckUnconsumed({"gemma", "--UNDEFINED", "--weights=x", "--XXX"}, 1);
}

}  // namespace gcpp
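The test above exercises the same parsing flow a binary uses. A minimal sketch of that flow, mirroring the `run.cc` main() further below (error handling omitted):

  #include "gemma/gemma_args.h"

  int main(int argc, char** argv) {
    // Tracks which argv entries were recognized by some ArgsBase subclass.
    gcpp::ConsumedArgs consumed(argc, argv);
    // Parses loader, threading and inference flags in one pass.
    gcpp::GemmaArgs args(argc, argv, consumed);
    // Aborts with a diagnostic if e.g. --UNDEFINED was passed.
    consumed.AbortIfUnconsumed();
    // ... use args.loader / args.threading / args.inference ...
    return 0;
  }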
@@ -16,6 +16,7 @@
#include "gemma/kv_cache.h"

#include <stddef.h>
#include <vector>

#include "gemma/configs.h"
#include "gemma/gemma_args.h"
@@ -50,8 +51,16 @@ KVCache KVCache::Copy() {
  KVCache copy(kv_cache.Extents(), allocator_);

  CopyMat(kv_cache, copy.kv_cache);

  return copy;
}

std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches) {
  std::vector<KVCachePtr> ptrs;
  ptrs.reserve(kv_caches.size());
  for (size_t i = 0; i < kv_caches.size(); ++i) {
    ptrs.push_back(kv_caches[i].ToPtr());
  }
  return ptrs;
}

}  // namespace gcpp
@@ -18,6 +18,9 @@
#include <stddef.h>

#include <optional>
#include <vector>

#include "gemma/configs.h"     // ModelConfig
#include "gemma/gemma_args.h"  // InferenceArgs
#include "util/basics.h"       // BF16
@@ -27,18 +30,33 @@ namespace gcpp {
using KV_t = float;

// A non-owning view of a KVCache.
struct KVCachePtr {
  bool IsEmpty() const { return kv_cache.Rows() == 0; }
  size_t SeqLen() const;

  MatPtrT<KV_t> kv_cache;
};

struct KVCache {
  KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
          const Allocator& allocator);

  // Returns a deep copy of the KVCache. Use explicit function instead of
  // copy ctor to make the cost explicit.
  KVCache Copy();

  size_t SeqLen() const { return kv_cache.Rows(); }
  size_t SeqLen() const {
    return kv_cache.Rows();
  }

  MatStorageT<KV_t> kv_cache;  // [seq_len, layers * kv_heads * qkv_dim * 2]

  KVCachePtr ToPtr() {
    return KVCachePtr{
        .kv_cache = kv_cache,
    };
  }

 private:
  const Allocator& allocator_;
@@ -46,6 +64,13 @@ struct KVCache {
  KVCache(const Extents2D& kv_extents, const Allocator& allocator);
};

inline size_t KVCachePtr::SeqLen() const {
  return kv_cache.Rows();
}

// Convenience function to create views into KVCaches.
std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches);

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
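As a quick sanity check of the shape comment above, using the configuration from the test below (2 layers, 4 KV heads, qkv_dim 256), each cached position holds K and V for every layer and head; a sketch, not part of the source:

  // cols = layers * kv_heads * qkv_dim * 2 (K and V)
  //      = 2 * 4 * 256 * 2 = 4096 floats per row; seq_len rows in total.
  static_assert(2 * 4 * 256 * 2 == 4096, "KV row width for the test config");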
gemma/kv_cache_test.cc
@@ -0,0 +1,43 @@
#include "gemma/kv_cache.h"

#include <cstddef>
#include <vector>

#include "gtest/gtest.h"
#include "gemma/configs.h"
#include "gemma/gemma_args.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h"

namespace gcpp {
namespace {

TEST(KVCacheTest, KVCacheToPtrs) {
  ModelConfig model_config;
  model_config.max_seq_len = 1024;
  model_config.num_layers = 2;
  for (int i = 0; i < model_config.num_layers; ++i) {
    model_config.layer_configs.push_back(LayerConfig());
    model_config.layer_configs.back().kv_heads = 4;
    model_config.layer_configs.back().qkv_dim = 256;
  }
  InferenceArgs inference_args;
  inference_args.seq_len = 1024;
  RuntimeConfig runtime_config;
  runtime_config.attention_impl = AttentionImpl::kFlash;
  ThreadingArgs threading_args;
  ThreadingContext ctx(threading_args);
  std::vector<KVCache> caches;
  caches.emplace_back(model_config, inference_args, runtime_config,
                      ctx.allocator);
  inference_args.seq_len = 512;
  caches.emplace_back(model_config, inference_args, runtime_config,
                      ctx.allocator);

  std::vector<KVCachePtr> ptrs = ToKVCachePtrs({caches.data(), caches.size()});
  ASSERT_EQ(ptrs.size(), 2);
  EXPECT_EQ(ptrs[0].kv_cache.Row(0), caches[0].kv_cache.Row(0));
  EXPECT_EQ(ptrs[1].kv_cache.Row(0), caches[1].kv_cache.Row(0));
}

}  // namespace
}  // namespace gcpp
@@ -221,6 +221,8 @@ static size_t DeduceNumLayers(const KeyVec& keys) {
// This works with or without type prefixes because it searches for substrings.
static int DeduceLayerTypes(const BlobReader& reader) {
  int layer_types = 0;
  bool has_key_norm = false;
  bool has_query_norm = false;
  for (size_t key_idx = 0; key_idx < reader.Keys().size(); ++key_idx) {
    const std::string& key = reader.Keys()[key_idx];
    if (key.find("qkv_ein_w") != std::string::npos) {  // NOLINT
@@ -232,6 +234,15 @@ static int DeduceLayerTypes(const BlobReader& reader) {
        layer_types |= kDeduced448;
      }
    }
    if (key.find("key_norm") != std::string::npos) {  // NOLINT
      has_key_norm = true;
    }
    if (key.find("query_norm") != std::string::npos) {  // NOLINT
      has_query_norm = true;
    }
  }
  if (has_key_norm && has_query_norm) {
    layer_types |= kDeducedKqNorm;
  }
  return layer_types;
}
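To illustrate the substring-based deduction (a hedged sketch; these key strings are hypothetical examples, only kDeduced448 and kDeducedKqNorm come from the code above):

  // Hypothetical blob keys; only the substrings matter, prefixes are ignored.
  //   "F32:qkv_ein_w"    -> examined for the 448-style shape (kDeduced448)
  //   "BF16:key_norm"    -> has_key_norm = true
  //   "BF16:query_norm"  -> has_query_norm = true
  // => DeduceLayerTypes() returns a value with kDeducedKqNorm set; a file
  //    containing only one of the two norm keys leaves that flag unset.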
gemma/query.h
@@ -0,0 +1,186 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_QUERY_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_QUERY_H_

#include <vector>

#include "gemma/gemma_args.h"
#include "gemma/kv_cache.h"
#include "util/basics.h"
#include "hwy/aligned_allocator.h"  // Span
#include "hwy/base.h"

namespace gcpp {

struct PerQuery {
  PromptTokens prompt;

  // Position in the KV cache: initially zero for the first turn, or when
  // multi-turn is NOT desired. Incremented by prefill and `StreamAndUpdateEOS`.
  size_t mutable_pos;
  // Allows computing the last prefill token as `mutable_pos - initial_pos`,
  // which might differ from `prompt.size() - 1` for prefix-LM.
  size_t initial_pos;
  // Zero for causal attention, or the end of the prefix for prefix-LM style
  // attention in Paligemma.
  size_t prefix_end;

  KVCachePtr kv_cache;

  // Previous token generated for this query, or the last prompt token. Will be
  // fed into the next Transformer() call.
  int prev_token = 0;
};

// Array of `PerQuery`. Referenced by `QBatch` and passed to `GenerateBatch`.
struct AllQueries {
  AllQueries() = default;

  // For `GenerateSingleT`: same prompt/pos, replicated for each KV cache.
  AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
             const hwy::Span<KVCachePtr>& kv_caches) {
    per_query_.reserve(kv_caches.size());
    for (size_t i = 0; i < kv_caches.size(); ++i) {
      HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
      per_query_.push_back(PerQuery{
          .prompt = prompt,
          .mutable_pos = pos,
          .initial_pos = pos,
          .prefix_end = prefix_end,
          .kv_cache = kv_caches[i],
      });
    }
  }

  AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
             const hwy::Span<KVCache>& kv_caches)
      : AllQueries(prompt, pos, prefix_end,
                   hwy::Span<KVCachePtr>(ToKVCachePtrs(kv_caches))) {}

  // Batch of queries with initial position set to zero. Causal attention
  // is requested via empty or all-zero `prefix_end`.
  AllQueries(
      const hwy::Span<const PromptTokens>& prompts,
      const hwy::Span<KVCachePtr>& kv_caches,
      const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
    HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
    per_query_.reserve(prompts.size());
    for (size_t i = 0; i < prompts.size(); ++i) {
      HWY_ASSERT(kv_caches.size() == 0 ||
                 kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
      per_query_.push_back(PerQuery{
          .prompt = prompts[i],
          .mutable_pos = 0,
          .initial_pos = 0,
          .prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i],
          .kv_cache = kv_caches.size() == 0 ? KVCachePtr() : kv_caches[i],
      });
    }
  }

  AllQueries(
      const hwy::Span<const PromptTokens>& prompts,
      const hwy::Span<KVCache>& kv_caches,
      const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>())
      : AllQueries(prompts, hwy::Span<KVCachePtr>(ToKVCachePtrs(kv_caches)),
                   prefix_end) {}

  void Reserve(size_t size) { per_query_.reserve(size); }
  void Append(const PerQuery& query) { per_query_.push_back(query); }

  size_t NumQueries() const { return per_query_.size(); }

  PerQuery& operator[](size_t query_idx) {
    HWY_DASSERT(query_idx < NumQueries());
    return per_query_[query_idx];
  }
  const PerQuery& operator[](size_t query_idx) const {
    HWY_DASSERT(query_idx < NumQueries());
    return per_query_[query_idx];
  }

 private:
  std::vector<PerQuery> per_query_;
};

// View into AllQueries: either a batch of queries, or a single query for use
// in PrefillTBatch or GenerateSingleT. Cheap to create because it holds a
// reference to AllQueries.
class QBatch {
 public:
  QBatch(size_t start, size_t max_size, AllQueries& queries)
      : start_(start),
        max_size_(max_size),
        queries_(queries),
        size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) {
    HWY_ASSERT(max_size_ <= kMaxBatchSize);
    HWY_DASSERT(size_ != 0);
    HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
    query_idx_.reserve(size_);
    for (size_t i = 0; i < size_; ++i) {
      query_idx_.push_back(start_ + i);
    }
  }

  // Returns a single-query view starting at `qi` relative to this batch.
  QBatch Single(size_t qi) const { return QBatch(QueryIdx(qi), 1, queries_); }

  // How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`.
  size_t Size() const { return size_; }

  // Returns index for use with `AllQueries` and `BatchStreamToken`.
  size_t QueryIdx(size_t qi) const {
    HWY_DASSERT(qi < size_);
    return query_idx_[qi];
  }

  // Accessor functions to bridge the previous SoA and current AoS layout.
  const PromptTokens& Prompt(size_t qi) const {
    return queries_[QueryIdx(qi)].prompt;
  }
  size_t Pos(size_t qi) const { return queries_[QueryIdx(qi)].mutable_pos; }
  size_t& MutablePos(size_t qi) { return queries_[QueryIdx(qi)].mutable_pos; }
  size_t InitialPos(size_t qi) const {
    return queries_[QueryIdx(qi)].initial_pos;
  }
  size_t PrefixEnd(size_t qi) const {
    return queries_[QueryIdx(qi)].prefix_end;
  }
  KVCachePtr& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
  int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }

  // Makes `query_idx_[to]` point to query `from` in `queries_`; only used
  // when the QBatch has fewer occupied slots than there are queries.
  void Insert(size_t from, size_t to) {
    if (from == to) return;
    HWY_ASSERT(!queries_[from].kv_cache.IsEmpty());
    HWY_ASSERT(queries_[to].kv_cache.IsEmpty());
    // Conceptually, insert query `from` at location `to`.
    query_idx_[to] = from;
  }

 protected:
  size_t start_;
  size_t max_size_;
  AllQueries& queries_;
  std::vector<size_t> query_idx_;
  size_t size_;
};

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_QUERY_H_
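A minimal usage sketch for the types above, assuming `PromptTokens` is constructible from a pointer and length (as its span-like usage suggests) and that `ptrs` came from `ToKVCachePtrs`:

  std::vector<int> tokens = {2, 651, 6575};  // arbitrary token ids
  PromptTokens prompt(tokens.data(), tokens.size());
  AllQueries all(prompt, /*pos=*/0, /*prefix_end=*/0,
                 hwy::Span<KVCachePtr>(ptrs.data(), ptrs.size()));
  QBatch qbatch(/*start=*/0, /*max_size=*/kMaxBatchSize, all);
  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
    // Prompt(qi), Pos(qi) and KV(qi) expose the per-query state.
  }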
gemma/run.cc
@@ -89,9 +89,11 @@ std::string GetPrompt(const InferenceArgs& inference) {
}

// The main Read-Eval-Print Loop.
void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
               const Gemma& gemma, KVCache& kv_cache, MatMulEnv& env) {
void ReplGemma(const GemmaArgs& args, const Gemma& gemma, KVCache& kv_cache,
               MatMulEnv& env) {
  PROFILER_ZONE("Gen.misc");
  const InferenceArgs& inference = args.inference;
  const int verbosity = inference.verbosity;
  size_t abs_pos = 0;                     // across turns
  size_t tokens_generated_this_turn = 0;  // differentiates prefill from reply
  size_t prompt_size = 0;
@@ -113,12 +115,12 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
    HWY_ASSERT(image.ReadPPM(inference.image_file.path));
    const size_t image_size = config.vit_config.image_size;
    image.Resize(image_size, image_size);
    RuntimeConfig runtime_config = {.verbosity = inference.verbosity,
                                    .use_spinning = threading.spin};
    RuntimeConfig runtime_config = {.verbosity = verbosity,
                                    .use_spinning = args.threading.spin};
    double image_tokens_start = hwy::platform::Now();
    gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
                              image_tokens, env);
    if (inference.verbosity >= 1) {
    if (verbosity >= 1) {
      double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
      fprintf(stderr,
              "\n\n[ Timing info ] Image token generation took: %d ms\n",
@@ -129,11 +131,15 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
  // callback function invoked for each generated token.
  auto batch_stream_token = [&](size_t query_idx, size_t pos, int token,
                                float) {
    std::string token_text;
    HWY_ASSERT(gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text));

    HWY_ASSERT(pos == abs_pos);
    ++abs_pos;

    std::string token_text;
    if (!gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text)) {
      if (token == -2) return true;  // Gemma 3 ViT?
      HWY_WARN("Failed to decode token %d.", token);
    }

    const bool in_prompt = tokens_generated_this_turn < prompt_size;
    const bool first_response_token = tokens_generated_this_turn == prompt_size;
    ++tokens_generated_this_turn;
@@ -185,7 +191,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
  TimingInfo timing_info = {.verbosity = inference.verbosity};
  RuntimeConfig runtime_config = {.verbosity = inference.verbosity,
                                  .batch_stream_token = batch_stream_token,
                                  .use_spinning = threading.spin};
                                  .use_spinning = args.threading.spin};
  inference.CopyTo(runtime_config);
  std::vector<int> prompt;
  size_t prefix_end = 0;
@@ -248,14 +254,14 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
  }
}

void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
         const InferenceArgs& inference) {
void Run(const GemmaArgs& args) {
  PROFILER_ZONE("Run.misc");

  ThreadingContext ctx(threading);
  ThreadingContext ctx(args.threading);
  MatMulEnv env(ctx);
  const InferenceArgs& inference = args.inference;
  if (inference.verbosity >= 3) env.print_best = true;
  const Gemma gemma(loader, inference, ctx);
  const Gemma gemma(args, ctx);
  KVCache kv_cache(gemma.Config(), inference, ctx.allocator);

  if (inference.verbosity >= 1) {
@@ -283,13 +289,12 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
    if (inference.IsInteractive()) {
      std::cout << "\033[2J\033[1;1H"  // clear screen
                << kAsciiArtBanner << "\n\n";
      ShowConfig(loader, threading, inference, gemma.Config(),
                 gemma.WeightReadMode(), ctx);
      ShowConfig(args, gemma.Config(), gemma.WeightReadMode(), ctx);
      std::cout << "\n" << instructions << "\n";
    }
  }

  ReplGemma(threading, inference, gemma, kv_cache, env);
  ReplGemma(args, gemma, kv_cache, env);
}

}  // namespace gcpp
@@ -298,17 +303,24 @@ int main(int argc, char** argv) {
  gcpp::InternalInit();
  {
    // Negligible CPU time.
    gcpp::LoaderArgs loader(argc, argv);
    gcpp::ThreadingArgs threading(argc, argv);
    gcpp::InferenceArgs inference(argc, argv);
    gcpp::ConsumedArgs consumed(argc, argv);
    gcpp::GemmaArgs args(argc, argv, consumed);

    if (gcpp::HasHelp(argc, argv)) {
      std::cerr << gcpp::kAsciiArtBanner;
      gcpp::ShowHelp(loader, threading, inference);
      fprintf(stderr,
              "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
              "==========================================================\n\n"
              "*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
              "--weights gemma2-2b-it-sfp.sbs\n\n");
      args.Help();
      return 0;
    }

    gcpp::Run(loader, threading, inference);
    // After `HasHelp` so that we print --help even if unconsumed args remain.
    consumed.AbortIfUnconsumed();

    gcpp::Run(args);
  }
  PROFILER_PRINT_RESULTS();  // Must call outside the zone above.
  return 0;
@@ -1,3 +1,18 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "gemma/tensor_info.h"

#include <stddef.h>

@@ -1,3 +1,18 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_
gemma/tensor_stats.cc
@@ -0,0 +1,205 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "gemma/tensor_stats.h"

#if GCPP_TENSOR_STATS
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

#include <atomic>
#include <cmath>
#include <memory>

#include "io/io.h"
#include "util/mat.h"
#include "util/threading_context.h"
#include "util/zones.h"
#include "hwy/profiler.h"  // StringTable

// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/tensor_stats.cc"  // NOLINT
// clang-format on
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "ops/dot-inl.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {

float Correlation(const float* x, size_t num) {
  double sum = 0.0;
  for (size_t i = 0; i < num; ++i) {
    sum += x[i];
  }
  const double mean = sum / static_cast<double>(num);

  double numerator = 0.0;
  double sum_sq_current = 0.0;
  double sum_sq_next = 0.0;

  for (size_t i = 0; i < num - 1; ++i) {
    const double diff_current = static_cast<double>(x[i]) - mean;
    const double diff_next = static_cast<double>(x[i + 1]) - mean;

    numerator += diff_current * diff_next;
    sum_sq_current += diff_current * diff_current;
    sum_sq_next += diff_next * diff_next;
  }

  if (sum_sq_current == 0.0 || sum_sq_next == 0.0) return 0.0f;
  const double denominator = std::sqrt(sum_sq_current * sum_sq_next);
  const float corr = static_cast<float>(numerator / denominator);
  HWY_DASSERT(-1.0f <= corr && corr <= 1.0f);
  return corr;
}
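For reference, this is the standard lag-1 autocorrelation estimate; in LaTeX:

  r_1 = \frac{\sum_{i=0}^{n-2} (x_i - \bar{x})(x_{i+1} - \bar{x})}
             {\sqrt{\sum_{i=0}^{n-2} (x_i - \bar{x})^2 \cdot \sum_{i=0}^{n-2} (x_{i+1} - \bar{x})^2}},
  \qquad \bar{x} = \frac{1}{n} \sum_{i=0}^{n-1} x_i,

with r_1 = 0 returned when either variance term vanishes (constant input).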

// Only write tensor data the first time it is encountered per layer. This is
// a concurrent string+layer -> flag map which avoids std::mutex (incompatible
// with fibers). We use a string table to index into per-layer atomic flags.
static bool ShouldWrite(const char* name, size_t layer_idx) {
  constexpr size_t kMaxNames = 128;
  constexpr size_t kMaxLayers = 128;
  HWY_DASSERT(layer_idx < kMaxLayers);
  static hwy::StringTable<kMaxNames> s_table;
  const size_t name_idx = s_table.Add(name);
  static std::atomic_flag flags[kMaxNames * kMaxLayers] = {};
  return !flags[name_idx * kMaxLayers + layer_idx].test_and_set(
      std::memory_order_acq_rel);
}

std::unique_ptr<File> MaybeOpenFile(size_t layer_idx, const MatPtr& type_erased,
                                    const Path& tensor_output) {
  if (tensor_output.Empty()) return nullptr;
  if (!ShouldWrite(type_erased.Name(), layer_idx)) return nullptr;
  char path[1024];
  snprintf(path, sizeof(path), "%s/%s_L%02zu_%zux%zu_%s.bin",
           tensor_output.path.c_str(), type_erased.Name(), layer_idx,
           type_erased.Rows(), type_erased.Cols(),
           TypeName(type_erased.GetType()));
  return OpenFileOrAbort(Path(path), "wb");
}

void MaybeWriteRow(const std::unique_ptr<File>& file, const MatPtr& type_erased,
                   size_t row_idx) {
  if (!file) return;
  const size_t bytes_per_row = type_erased.Cols() * type_erased.ElementBytes();
  file->Write(type_erased.RowBytes(row_idx), bytes_per_row,
              bytes_per_row * row_idx);
}

// First dispatch to the type, then parallel over rows, then vectorized
// decompress and Notify for each value.
void UpdateStatsT(TensorStats& stats, size_t layer_idx,
                  const MatPtr& type_erased, ThreadingContext& ctx, int flags,
                  size_t cluster_idx, Parallelism parallelism) {
  std::unique_ptr<File> file =
      MaybeOpenFile(layer_idx, type_erased, ctx.tensor_output);

  if ((flags & kTensorStatsIsWeight) && layer_idx != 0) {
    // Still compute stats, but remember not to print them.
    stats.Get(layer_idx, 0).DoNotPrint();
  }

  CallUpcasted(&type_erased, [&](const auto* mat) {
    const size_t cols = mat->Cols();

    ParallelFor(
        parallelism, mat->Rows(), ctx, cluster_idx, Callers::kTensorStats,
        [&](size_t row_idx, size_t global_idx) {
          GCPP_ZONE(ctx, global_idx, Zones::kGenStats);

          auto* HWY_RESTRICT row = mat->Row(row_idx);
          MaybeWriteRow(file, type_erased, row_idx);

          using Packed = hwy::RemoveCvRef<decltype(*row)>;
          PackedSpan<Packed> packed(const_cast<Packed*>(row), cols);

          TensorStatsAccumulator& my_stats = stats.Get(layer_idx, global_idx);
          my_stats.NotifyCond(ConditionNumber(row, cols));

          namespace hn = hwy::HWY_NAMESPACE;
          hn::ScalableTag<float> df;
          using VF = hn::Vec<decltype(df)>;
          HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
          HWY_ALIGN float buf[2 * hn::MaxLanes(df)];

          size_t packed_ofs = 0;
          if (cols >= 2 * NF) {
            for (; packed_ofs <= cols - 2 * NF; packed_ofs += 2 * NF) {
              VF v0, v1;
              Decompress2(df, packed, packed_ofs, v0, v1);
              hn::Store(v0, df, buf);
              hn::Store(v1, df, buf + NF);
              const VF min_mag = hn::Min(hn::Abs(v0), hn::Abs(v1));
              const VF max_mag = hn::Max(hn::Abs(v0), hn::Abs(v1));
              const float min = hn::ReduceMin(df, min_mag);
              if (min != 0.0f) {  // Avoid division by zero.
                my_stats.NotifyGroup(min, hn::ReduceMax(df, max_mag));
              }

              for (size_t i = 0; i < 2 * NF; ++i) {
                my_stats.Notify(buf[i], row_idx, packed_ofs + i);
              }
              my_stats.NotifyCorr(Correlation(buf, 2 * NF));
            }
          }

          // Zero to two vectors remaining.
          for (; packed_ofs < cols; packed_ofs += NF) {
            const size_t remaining = HWY_MIN(NF, cols - packed_ofs);
            DecompressAndZeroPad(df, packed, packed_ofs, buf, remaining);
            // Skip NotifyGroup for this partial group.
            for (size_t i = 0; i < remaining; ++i) {
              my_stats.Notify(buf[i], row_idx, packed_ofs + i);
            }
            my_stats.NotifyCorr(Correlation(buf, remaining));
          }
        });
  });
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(UpdateStatsT);

// Must reside in .cc file so that we can #include compress-inl.h.
void TensorStats::Notify(size_t layer_idx, const MatPtr& type_erased,
                         ThreadingContext& ctx, int flags, size_t cluster_idx,
                         Parallelism parallelism) {
  // Ignore empty tensors.
  if (type_erased.GetType() == Type::kUnknown || type_erased.Cols() == 0) {
    return;
  }
  HWY_DYNAMIC_DISPATCH(UpdateStatsT)(*this, layer_idx, type_erased, ctx, flags,
                                     cluster_idx, parallelism);
}

}  // namespace gcpp
#endif  // HWY_ONCE

#endif  // GCPP_TENSOR_STATS
gemma/tensor_stats.h
@@ -0,0 +1,347 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_STATS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_STATS_H_

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

#include "util/basics.h"
#include "hwy/base.h"

#ifndef GCPP_TENSOR_STATS
#define GCPP_TENSOR_STATS 0
#endif

#include "util/mat.h"
#include "util/threading_context.h"

#if GCPP_TENSOR_STATS
#include <cmath>
#include <vector>

#include "hwy/stats.h"
#endif  // GCPP_TENSOR_STATS

namespace gcpp {

// For flags. Used to inhibit printing per-layer stats for weights.
HWY_INLINE_VAR constexpr int kTensorStatsIsWeight = 1;

#if GCPP_TENSOR_STATS

HWY_INLINE_VAR constexpr size_t kStatsMaxCols = 8192;

// Separate summary of the per-layer stats, updated by `TensorStatsAccumulator`.
// We pass per-layer statistics such as the mean value to `hwy::Stats::Notify`
// to see the distribution of per-layer means.
struct TensorStatsAcrossLayers {
  bool IsEmpty() const { return s_frobenius.Count() == 0; }

  void Print() {
    const int skip = hwy::Stats::kNoGeomean;
    fprintf(stderr, "frob    %s\n", s_frobenius.ToString(skip).c_str());
    fprintf(stderr, "cnd.min %s\n", s_cond_min.ToString(skip).c_str());
    fprintf(stderr, "cnd.avg %s\n", s_cond_avg.ToString(skip).c_str());
    fprintf(stderr, "cnd.max %s\n", s_cond_max.ToString(skip).c_str());
    fprintf(stderr, "val.min %s\n", s_val_min.ToString(skip).c_str());
    fprintf(stderr, "val.avg %s\n", s_val_avg.ToString(skip).c_str());
    fprintf(stderr, "val.krt %s\n", s_val_kurt.ToString(skip).c_str());
    fprintf(stderr, "mag.min %s\n", s_mag_min.ToString(skip).c_str());
    fprintf(stderr, "mag.avg %s\n", s_mag_avg.ToString(skip).c_str());
    fprintf(stderr, "mag.max %s\n", s_mag_max.ToString(skip).c_str());
    if (hwy::ScalarAbs(s_corr_avg.Max()) > 0.05f) {
      fprintf(stderr, "cor.avg %s\n", s_corr_avg.ToString(skip).c_str());
    }
    fprintf(stderr, "cor.max %s\n", s_corr_max.ToString(skip).c_str());
    fprintf(stderr, "rng_avg %s\n", s_range_avg.ToString(skip).c_str());
    fprintf(stderr, "exp.min %s\n", s_exp_min.ToString(skip).c_str());
    fprintf(stderr, "exp.max %s\n", s_exp_max.ToString(skip).c_str());
    fprintf(stderr, "exp.mod %s\n", s_exp_mode.ToString(skip).c_str());
    if (s_exp_subnormal.Min() != 0.0f) {
      fprintf(stderr, "exp.sub %s\n", s_exp_subnormal.ToString(skip).c_str());
    }
    if (s_big_cols.Count() != 0) {
      fprintf(stderr, "bigCols %s\n", s_big_cols.ToString(skip).c_str());
      const size_t modal_col = b_big_cols.ModalBinIdx();
      const size_t num_outlier_cols = b_big_cols.NumNonzero();
      if (num_outlier_cols > 256) {
        fprintf(stderr, "bigCols: all up to %zu (max at %zu: %u layers):\n",
                b_big_cols.LastNonzero(), modal_col, b_big_cols.Bin(modal_col));
      } else {
        fprintf(stderr, "bigCols (max at %zu: %u layers):\n", modal_col,
                b_big_cols.Bin(modal_col));
        for (size_t i = 0; i < kStatsMaxCols; ++i) {
          if (b_big_cols.Bin(i) > 2) {
            fprintf(stderr, "  %3zu: %u\n", i, b_big_cols.Bin(i));
          }
        }
      }
    }
    fprintf(stderr, "\n");
  }

  hwy::Stats s_frobenius;

  hwy::Stats s_cond_min;
  hwy::Stats s_cond_avg;
  hwy::Stats s_cond_max;

  hwy::Stats s_val_min;
  hwy::Stats s_val_avg;
  hwy::Stats s_val_kurt;

  hwy::Stats s_mag_min;
  hwy::Stats s_mag_avg;
  hwy::Stats s_mag_max;

  hwy::Stats s_corr_avg;
  hwy::Stats s_corr_max;

  hwy::Stats s_range_avg;

  hwy::Stats s_exp_min;
  hwy::Stats s_exp_max;
  hwy::Stats s_exp_mode;
  hwy::Stats s_exp_subnormal;

  hwy::Stats s_big_cols;                // total number of outlier cols
  hwy::Bins<kStatsMaxCols> b_big_cols;  // # layers with outlier per col
};

// Per-thread and layer.
class TensorStatsAccumulator {
 public:
  void Notify(float val, size_t row_idx, size_t col_idx) {
    const double dval = static_cast<double>(val);
    sum_sq_ += dval * dval;

    s_val_.Notify(val);
    const float mag = hwy::ScalarAbs(val);

    if (HWY_UNLIKELY(mag >= 64.0f)) {
      if (row_idx < kMaxBatchSize) b_big_row_.Notify(row_idx);
      if (col_idx < kStatsMaxCols) b_big_col_.Notify(col_idx);
    }

    // Skip zero so we can see the lowest actual magnitude.
    if (mag != 0.0f && mag != -0.0f) s_mag_.Notify(mag);

    const uint32_t binary32 = hwy::BitCastScalar<uint32_t>(mag);
    // Use biased exponent because Bins wants unsigned values.
    const uint32_t biased_exp = binary32 >> 23;
    HWY_DASSERT(biased_exp < 256);  // already cleared sign bit
    b_exp256_.Notify(biased_exp);
  }
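A worked example of the exponent binning above (values follow IEEE-754 binary32):

  // val = 1.0f -> bits 0x3F800000 -> biased_exp = 0x3F800000 >> 23 = 127
  // val = 0.5f -> bits 0x3F000000 -> biased_exp = 126
  // Subnormals and zero have biased_exp = 0 and land in bin 0, which
  // NotifyAcrossLayer() below strips so they do not hide the true minimum.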

  void DoNotPrint() { skip_.fetch_or(1); }
  bool ShouldPrint() const { return skip_.load() == 0; }

  // Vector code computed the min/max of a group (= two vectors); this is
  // faster than doing it in `Notify`.
  void NotifyGroup(float min, float max) {
    s_group_min_.Notify(min);
    s_group_max_.Notify(max);
    // Caller ensures min != 0.
    s_group_range_.Notify(max / min);
  }

  void NotifyCorr(float corr) { s_corr_.Notify(corr); }
  void NotifyCond(double cond) { s_cond_.Notify(cond); }

  void Assimilate(const TensorStatsAccumulator& other) {
    skip_.fetch_or(other.skip_.load());

    sum_sq_ += other.sum_sq_;
    b_exp256_.Assimilate(other.b_exp256_);
    b_big_row_.Assimilate(other.b_big_row_);
    b_big_col_.Assimilate(other.b_big_col_);
    s_val_.Assimilate(other.s_val_);
    s_mag_.Assimilate(other.s_mag_);
    s_corr_.Assimilate(other.s_corr_);
    s_group_min_.Assimilate(other.s_group_min_);
    s_group_max_.Assimilate(other.s_group_max_);
    s_group_range_.Assimilate(other.s_group_range_);
  }

  // Called on the per-layer representative after reducing across threads.
  void NotifyAcrossLayer(TensorStatsAcrossLayers& s) {
    s.s_frobenius.Notify(std::sqrt(sum_sq_));

    s.s_cond_min.Notify(s_cond_.Min());
    s.s_cond_avg.Notify(s_cond_.Mean());
    s.s_cond_max.Notify(s_cond_.Max());

    s.s_val_min.Notify(s_val_.Min());
    s.s_val_avg.Notify(s_val_.Mean());
    s.s_val_kurt.Notify(s_val_.Kurtosis());

    s.s_mag_min.Notify(s_mag_.Min());
    s.s_mag_avg.Notify(s_mag_.Mean());
    s.s_mag_max.Notify(s_mag_.Max());

    s.s_corr_avg.Notify(s_corr_.Mean());
    s.s_corr_max.Notify(s_corr_.Max());

    s.s_range_avg.Notify(s_group_range_.Mean());

    const uint32_t subnormals = b_exp256_.Bin(0);
    // Prevent subnormals from hiding the min exponent.
    b_exp256_.ResetBin(0);
    s.s_exp_min.Notify(b_exp256_.FirstNonzero());
    s.s_exp_max.Notify(b_exp256_.LastNonzero());
    s.s_exp_mode.Notify(b_exp256_.ModalBinIdx());
    s.s_exp_subnormal.Notify(subnormals);

    const uint32_t num_outliers = b_big_col_.NumNonzero();
    if (num_outliers != 0) {
      s.s_big_cols.Notify(num_outliers);
      // For each col, count the number of layers that have an outlier there.
      for (size_t i = 0; i < kStatsMaxCols; ++i) {
        if (b_big_col_.Bin(i) != 0) s.b_big_cols.Notify(i);
      }
    }
  }

  bool IsEmpty() const { return s_val_.Count() == 0; }

  void PrintAll() {
    fprintf(stderr, "Frob %.2E\n", std::sqrt(sum_sq_));
    const int skip = hwy::Stats::kNoGeomean;
    fprintf(stderr, "cnd  %s\n", s_cond_.ToString(skip).c_str());
    fprintf(stderr, "val  %s\n", s_val_.ToString(skip).c_str());
    fprintf(stderr, "mag  %s\n", s_mag_.ToString(skip).c_str());
    fprintf(stderr, "corr %s\n", s_corr_.ToString(skip).c_str());
    fprintf(stderr, "group_min %s\n", s_group_min_.ToString(skip).c_str());
    fprintf(stderr, "group_max %s\n", s_group_max_.ToString(skip).c_str());
    fprintf(stderr, "group_range %s\n", s_group_range_.ToString(skip).c_str());
    b_exp256_.Print("exp");
    PrintBinRanges(b_big_row_, "big row");
    PrintBinRanges(b_big_col_, "big col");
    fprintf(stderr, "\n");
  }

 private:
  template <size_t N>
  void PrintBinRanges(const hwy::Bins<N>& b, const char* name) {
    uint64_t total = 0;
    for (size_t i = 0; i < N; ++i) {
      total += b.Bin(i);
    }
    if (total == 0) return;

    // If all bins are at least 10% of a uniform distribution, print the range
    // to vastly reduce the log size.
    const size_t min = HWY_MAX(1, total / (N * 10));
    size_t last = 0;
    for (; last < N; ++last) {
      if (b.Bin(last) < min) break;
    }
    if (last >= N / 2) {
      // Also require all subsequent bins to be zero, otherwise we should
      // print the outlier bins.
      bool all_zero = true;
      for (size_t i = last + 1; i < N; ++i) {
        if (b.Bin(i) != 0) {
          all_zero = false;
          break;
        }
      }
      if (all_zero) {
        fprintf(stderr, "%s: uniform up to %zu\n", name, last);
        return;
      }
    }

    b.Print(name, /*skip_zero=*/true);
  }

  double sum_sq_ = 0.0;       // for Frobenius norm
  hwy::Bins<256> b_exp256_;   // exponent
  hwy::Bins<kMaxBatchSize> b_big_row_;
  hwy::Bins<kStatsMaxCols> b_big_col_;
  hwy::Stats s_val_;
  hwy::Stats s_mag_;
  hwy::Stats s_cond_;  // condition number
  hwy::Stats s_corr_;  // lag-1 autocorrelation
  hwy::Stats s_group_min_;
  hwy::Stats s_group_max_;
  hwy::Stats s_group_range_;
  std::atomic<int> skip_{0};
};

class TensorStats {
 public:
  TensorStats(size_t num_layers, size_t max_workers)
      : num_layers_(num_layers),
        max_workers_(max_workers),
        acc_(num_layers * max_workers) {}

  // Parallelized across rows. If `ctx.tensor_output` is not empty, writes
  // tensor data to disk for offline analysis, once per tensor and layer.
  void Notify(size_t layer_idx, const MatPtr& type_erased,
              ThreadingContext& ctx, int flags = 0, size_t cluster_idx = 0,
              Parallelism parallelism = Parallelism::kFlat);

  // For use by `UpdateStatsT`.
  TensorStatsAccumulator& Get(size_t layer_idx, size_t global_idx) {
    const size_t idx = layer_idx * max_workers_ + global_idx;
    HWY_DASSERT(idx < acc_.size());
    return acc_[idx];
  }

  void ReduceAndPrint(const char* prefix) {
    for (size_t layer_idx = 0; layer_idx < num_layers_; ++layer_idx) {
      TensorStatsAccumulator& per_layer = Get(layer_idx, 0);
      for (size_t global_idx = 1; global_idx < max_workers_; ++global_idx) {
        per_layer.Assimilate(Get(layer_idx, global_idx));
      }
      if (per_layer.IsEmpty()) continue;
      per_layer.NotifyAcrossLayer(across_layers_);
      if (per_layer.ShouldPrint()) {
        fprintf(stderr, "-------------------- %s %zu\n", prefix, layer_idx);
        per_layer.PrintAll();
      }
    }

    if (!across_layers_.IsEmpty()) {
      fprintf(stderr, "================= across layers %s\n", prefix);
      across_layers_.Print();
    }
  }

 private:
  size_t num_layers_;
  size_t max_workers_;
  std::vector<TensorStatsAccumulator> acc_;
  TensorStatsAcrossLayers across_layers_;
};

#else  // GCPP_TENSOR_STATS

class TensorStats {
 public:
  TensorStats(size_t, size_t) {}
  void Notify(size_t, const MatPtr&, ThreadingContext&, int = 0, size_t = 0,
              Parallelism = Parallelism::kFlat) {}
  void ReduceAndPrint(const char*) {}
};

#endif  // GCPP_TENSOR_STATS
}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_STATS_H_
gemma/vit.cc
@@ -78,13 +78,9 @@ class VitAttention {
    const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
    PROFILER_ZONE("Gen.VitAttention.DotSoftmax");

    // Shift Q, K, VT to MatStorageT.
    MatStorageT<float> Q("Q2", Extents2D(num_tokens_, qkv_dim),
                         env_.ctx.allocator, MatPadding::kPacked);
    MatStorageT<float> K("K2", Extents2D(seq_len, qkv_dim), env_.ctx.allocator,
                         MatPadding::kPacked);
    MatStorageT<float> C("C2", Extents2D(num_tokens_, seq_len),
                         env_.ctx.allocator, MatPadding::kPacked);
    MatPtrT<float>& Q = activations_.attention.vit_Q;
    MatPtrT<float>& K = activations_.attention.vit_K;
    MatPtrT<float>& C = activations_.attention.vit_C;

    // Initialize att_out to zero prior to head loop.
    ZeroInit(activations_.attention.att_out);
@@ -295,19 +291,20 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
  const size_t model_dim = model_config.vit_config.model_dim;
  const size_t patch_width = model_config.vit_config.patch_width;
  const size_t num_tokens = model_config.vit_config.seq_len;
  const size_t patch_size = patch_width * patch_width * 3;
  const size_t patch_area = patch_width * patch_width * 3;
  const hwy::Divisor div_patch_dim(patch_width);
  HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
  HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size);
  HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_area);
  HWY_DASSERT(activations.x.Cols() == model_dim);
  (void)model_dim;
  // img/embedding/kernel has original shape (14, 14, 3, 1152)
  // H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
  // image_patches is (256, 14 * 14 * 3)
  // Must be padded, see `DoDecompressA`.
  MatStorageT<float> image_patches("patches", Extents2D(num_tokens, patch_size),
  MatStorageT<float> image_patches("patches", Extents2D(num_tokens, patch_area),
                                   env.ctx.allocator, MatPadding::kOdd);
  for (size_t i = 0; i < num_tokens; ++i) {
    image.GetPatch(i, image_patches.Row(i));
    image.GetPatch(i, div_patch_dim, image_patches.Row(i));
  }
  CallMatMul(image_patches, weights.vit_img_embedding_kernel,
             weights.vit_img_embedding_bias.PackedScale1(), env, activations.x);
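The shape comments above imply a quick consistency check; a sketch assuming the 224x224 input that a (14, 14, 3, 1152) patch kernel with 256 tokens implies:

  // 224 / 14 = 16 patches per side, 16 * 16 = 256 tokens per image;
  // each patch flattens to patch_area = 14 * 14 * 3 = 588 floats.
  static_assert(14 * 14 * 3 == 588, "patch_area for patch_width 14");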
@@ -431,12 +431,12 @@ void WeightsPtrs::CopyFrom(const WeightsPtrs& other) {
void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
                        ThreadingContext& ctx) {
  const size_t cluster_idx = 0;
  ParallelFor(ParallelismStrategy::kFlat, c_layers.size(), ctx, cluster_idx,
  ParallelFor(Parallelism::kFlat, c_layers.size(), ctx, cluster_idx,
              Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) {
                GetLayer(layer)->Fixup(mat_owners, ctx);
              });

  ParallelFor(ParallelismStrategy::kFlat, vit_layers.size(), ctx, cluster_idx,
  ParallelFor(Parallelism::kFlat, vit_layers.size(), ctx, cluster_idx,
              Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) {
                VitLayer(layer)->Fixup(mat_owners, ctx);
              });
@@ -527,7 +527,7 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,

  // Allocate in parallel because faulting in large tensors is slow.
  ParallelFor(
      ParallelismStrategy::kFlat, tensors.size(), ctx, /*cluster_idx=*/0,
      Parallelism::kFlat, tensors.size(), ctx, /*cluster_idx=*/0,
      Callers::kAllocateAndBindAll, [&](uint64_t task, size_t /*thread*/) {
        TensorToRead& tensor = tensors[task];
        MatPtr& mat = *tensor.mat;
@@ -586,10 +586,9 @@ static void DecompressToBF16(MatPtr& mat,
static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
                          const BlobReader& reader, ThreadingContext& ctx) {
  // Especially TSAN is slow enough to warrant hierarchical parallelism.
  const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD
                                           ? ParallelismStrategy::kHierarchical
                                           : ParallelismStrategy::kFlat;
  ParallelFor(strategy, tensors.size(), ctx, /*cluster_idx=*/0,
  const Parallelism parallelism =
      HWY_IS_DEBUG_BUILD ? Parallelism::kHierarchical : Parallelism::kFlat;
  ParallelFor(parallelism, tensors.size(), ctx, /*cluster_idx=*/0,
              Callers::kReadAllToBF16, [&](uint64_t task, size_t thread) {
                GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadAllToBF16);
                const TensorToRead& tensor = tensors[task];
@@ -677,7 +676,7 @@ static void ReadBatches(const BlobReader& reader,
                        const std::vector<IOBatch>& batches,
                        ThreadingContext& ctx) {
  // >5x speedup from parallel reads when cached.
  ParallelFor(ParallelismStrategy::kHierarchical, batches.size(), ctx,
  ParallelFor(Parallelism::kHierarchical, batches.size(), ctx,
              /*cluster_idx=*/0, Callers::kReadBatches,
              [&](uint64_t task, size_t thread) {
                GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadBatches);
@@ -96,7 +96,8 @@ struct LayerWeightsPtrs {
  // other values for purposes of the KV cache.
  LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config,
                   const TensorInfoRegistry& tensors)
      : finder_(LayerSuffix(layer_idx), tensors),
      : layer_idx(layer_idx),
        finder_(LayerSuffix(layer_idx), tensors),
        qkv_einsum_w(finder_("qkv_ein")),
        qkv_einsum_w1(finder_("qkv1_w")),
        qkv_einsum_w2(finder_("qkv2_w")),
@@ -135,6 +136,7 @@ struct LayerWeightsPtrs {
  }
  ~LayerWeightsPtrs() = default;

  const size_t layer_idx;
  const MatFinder finder_;

  // Files either have qkv_einsum_w with 2 stacked matrices or separate
@@ -106,7 +106,7 @@ void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs,
               ThreadingContext& ctx, size_t cluster_idx) {
  HWY_ASSERT(reader.Keys().size() == blobs.size());
  HWY_ASSERT(ranges.size() == blobs.size());
  ParallelFor(ParallelismStrategy::kWithinCluster, blobs.size(), ctx,
  ParallelFor(Parallelism::kWithinCluster, blobs.size(), ctx,
              cluster_idx, Callers::kTest, [&](size_t i, size_t /*thread*/) {
                HWY_ASSERT(ranges[i].bytes == blobs[i].size());
                reader.file().Read(ranges[i].offset, ranges[i].bytes,
@@ -122,7 +122,7 @@ void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2,
  const double t0 = hwy::platform::Now();
  HWY_WARN("Reading %zu GiB, %zu clusters: ", total_bytes >> 30,
           ctx.pools.NumClusters());
  ParallelFor(ParallelismStrategy::kAcrossClusters, 2, ctx, 0, Callers::kTest,
  ParallelFor(Parallelism::kAcrossClusters, 2, ctx, 0, Callers::kTest,
              [&](const size_t task, size_t cluster_idx) {
                ReadBlobs(task ? reader1 : reader2, task ? ranges1 : ranges2,
                          task ? blobs1 : blobs2, ctx, cluster_idx);
@@ -189,7 +189,7 @@ void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2,
  const double t0 = hwy::platform::Now();
  std::atomic<size_t> blobs_equal{};
  std::atomic<size_t> blobs_diff{};
  ParallelFor(ParallelismStrategy::kHierarchical, keys.size(), ctx, 0,
  ParallelFor(Parallelism::kHierarchical, keys.size(), ctx, 0,
              Callers::kTest, [&](size_t i, size_t /*thread*/) {
                const size_t mismatches =
                    BlobDifferences(blobs1[i], blobs2[i], keys[i]);
@@ -488,11 +488,10 @@ void BlobWriter::Add(const std::string& key, const void* data, size_t bytes) {
  EnqueueChunks(keys_.size() - 1, curr_offset_, bytes,
                static_cast<const uint8_t*>(data), writes);

  const ParallelismStrategy strategy = file_->IsAppendOnly()
                                           ? ParallelismStrategy::kNone
                                           : ParallelismStrategy::kFlat;
  const Parallelism parallelism =
      file_->IsAppendOnly() ? Parallelism::kNone : Parallelism::kFlat;
  ParallelFor(
      strategy, writes.size(), ctx_,
      parallelism, writes.size(), ctx_,
      /*cluster_idx=*/0, Callers::kBlobWriter,
      [this, &writes](uint64_t i, size_t /*thread*/) {
        const BlobRange& range = writes[i].range;
@@ -131,7 +131,7 @@ class BlobWriter {
  std::vector<size_t> blob_sizes_;
  ThreadingContext& ctx_;
  // Current offset in the file used for writing.
  int64_t curr_offset_ = 0;
  uint64_t curr_offset_ = 0;
};

}  // namespace gcpp
@@ -130,7 +130,7 @@ TEST(BlobStoreTest, TestNumBlobs) {
  HWY_ASSERT_EQ(reader.Keys().size(), num_blobs);

  ParallelFor(
      ParallelismStrategy::kFlat, num_blobs, ctx, /*cluster_idx=*/0,
      Parallelism::kFlat, num_blobs, ctx, /*cluster_idx=*/0,
      Callers::kTest, [&](uint64_t i, size_t /*thread*/) {
        HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(),
                             std::to_string(i).c_str());
@@ -20,6 +20,8 @@
#include <stdio.h>

#include <limits>
#include <string>
#include <vector>
#include <type_traits>

#include "hwy/tests/hwy_gtest.h"
io/io.cc
@@ -110,7 +110,8 @@ class FilePosix : public File {
        HWY_WARN(
            "Read failure at pos %zu within size %zu with offset %zu and "
            "errno %d\n",
            pos, size, offset, errno);
            static_cast<size_t>(pos), static_cast<size_t>(size),
            static_cast<size_t>(offset), errno);
        break;
      }
      pos += bytes_read;
@@ -130,7 +131,8 @@ class FilePosix : public File {
        HWY_WARN(
            "Write failure at pos %zu within size %zu with offset %zu and "
            "errno %d\n",
            pos, size, offset, errno);
            static_cast<size_t>(pos), static_cast<size_t>(size),
            static_cast<size_t>(offset), errno);
        break;
      }
      pos += bytes_written;
@@ -194,9 +196,9 @@ std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) {
namespace gcpp {

std::unique_ptr<File> OpenFileOrAbort(const Path& filename, const char* mode) {
  std::unique_ptr<File> file = OpenFileOrNull(filename, "r");
  std::unique_ptr<File> file = OpenFileOrNull(filename, mode);
  if (!file) {
    HWY_ABORT("Failed to open %s", filename.path.c_str());
    HWY_ABORT("Failed to open %s, errno %d", filename.path.c_str(), errno);
  }
  return file;
}
@@ -234,7 +236,9 @@ bool IOBatch::Add(void* mem, size_t bytes) {
  return true;
}

void InternalInit() {
int InternalInit() {
  // currently unused, except for init list ordering in GemmaaEnv.
  return 0;
}

uint64_t IOBatch::Read(const File& file) const {
io/io.h
@@ -150,7 +150,7 @@ std::string ReadFileToString(const Path& path);

// No-op in open-source. Must be called at the beginning of a binary, before
// any I/O or flag usage.
void InternalInit();
int InternalInit();

}  // namespace gcpp
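The `void`-to-`int` change lets callers force initialization order via a member initializer. A hedged sketch of the pattern (the class and field names here are illustrative, not the actual `GemmaEnv` layout):

  class EnvSketch {
    // Runs InternalInit() before any later members that do I/O or parse flags.
    int internal_init_ = gcpp::InternalInit();
    // ... members initialized afterwards may safely perform I/O ...
  };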
@@ -23,7 +23,9 @@ namespace gcpp {
namespace {

struct WriterArgs : public ArgsBase<WriterArgs> {
  WriterArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
  WriterArgs(int argc, char* argv[], ConsumedArgs& consumed) {
    InitAndParse(argc, argv, consumed);
  }

  Path output_weights;

@@ -38,12 +40,15 @@ struct WriterArgs : public ArgsBase<WriterArgs> {
}  // namespace gcpp

int main(int argc, char** argv) {
  gcpp::WriterArgs args(argc, argv);
  if (args.output_weights.Empty()) {
  gcpp::ConsumedArgs consumed(argc, argv);
  gcpp::GemmaArgs args(argc, argv, consumed);
  gcpp::WriterArgs writer_args(argc, argv, consumed);
  if (writer_args.output_weights.Empty()) {
    HWY_ABORT("Missing --output_weights flag, a file for the model weights.");
  }
  consumed.AbortIfUnconsumed();

  gcpp::GemmaEnv env(argc, argv);
  env.GetGemma()->Save(args.output_weights, env.Env().ctx);
  gcpp::GemmaEnv env(args);
  env.GetGemma()->Save(writer_args.output_weights, env.Env().ctx);
  return 0;
}
@@ -413,7 +413,8 @@ using DotKernelDefault =
template <class D, typename WT, typename VT>
HWY_INLINE float Dot(D d, const PackedSpan<const WT>& w, size_t w_ofs,
                     const VT* HWY_RESTRICT vec, size_t num) {
  return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelDefault());
  return DecompressAndCall(d, w, w_ofs, MakeConstSpan(vec, num),
                           DotKernelDefault());
}

// Adapter for two pointers, no bounds checking.
@@ -891,18 +891,6 @@ class DotStats {
  hwy::Stats s_times[kVariants];
};

// Returns normalized value in [-1, 1).
float RandomFloat(RngStream& rng) {
  const uint32_t exp = hwy::BitCastScalar<uint32_t>(1.0f);
  const uint32_t mantissa_mask = hwy::MantissaMask<float>();
  const uint32_t representation = exp | (rng() & mantissa_mask);
  const float f12 = hwy::BitCastScalar<float>(representation);
  HWY_DASSERT(1.0f <= f12 && f12 < 2.0f);  // exponent is 2^0, only mantissa
  const float f = (2.0f * (f12 - 1.0f)) - 1.0f;
  HWY_DASSERT(-1.0f <= f && f < 1.0f);
  return f;
}

// `raw` holds the decompressed values, so that the test measures only the
// error from the Dot algorithms, not the compression.
template <typename Packed>
@@ -1126,7 +1114,7 @@ void TestAllDot() {
  std::array<DotStats, kMaxWorkers> all_stats;

  ParallelFor(
      ParallelismStrategy::kWithinCluster, kReps, ctx, 0, Callers::kTest,
      Parallelism::kWithinCluster, kReps, ctx, 0, Callers::kTest,
      [&](size_t rep, size_t thread) {
        float* HWY_RESTRICT pa = a.Row(thread);
        float* HWY_RESTRICT pb = b.Row(thread);
@ -837,10 +837,11 @@ class MMImpl {
|
|||
hwy::platform::InvariantTicksPerSecond();
|
||||
const double flops = 2 * M * K * N * num_B / min_elapsed; // * 2 for FMA
|
||||
if (HWY_UNLIKELY(env.print_measurement && tuner.ShouldPrint())) {
|
||||
fprintf(stderr, "%zu,%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu\n",
|
||||
M, K, N, num_B, flops * 1E-9, min_elapsed * 1E3, cfg.MR(),
|
||||
cfg.MC(), cfg.KC(), cfg.NC(), StringFromOrder(cfg.Order()),
|
||||
cfg.InnerTasks());
|
||||
fprintf(
|
||||
stderr,
|
||||
"%4zu,%4zu,%4zu,B%zu,%7.1f,%.2f ms, MR%zu,%4zu,%4zu,%5zu,%-7s,%zu\n",
|
||||
M, K, N, num_B, flops * 1E-9, min_elapsed * 1E3, cfg.MR(), cfg.MC(),
|
||||
cfg.KC(), cfg.NC(), StringFromOrder(cfg.Order()), cfg.InnerTasks());
|
||||
}
|
||||
if (HWY_UNLIKELY(env.print_best && tuner.Best())) {
|
||||
const auto ratio = [&tuner](uint64_t ticks) -> double {
|
||||
|
|
@ -850,7 +851,8 @@ class MMImpl {
|
|||
const MMConfig& best = *tuner.Best();
|
||||
fprintf(
|
||||
stderr,
|
||||
"\n%zu,%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%.2f,%.2f\n",
|
||||
"\n%4zu,%4zu,%4zu,B%zu,%7.1f,%.2f ms, MR%zu,%4zu,%4zu,%5zu,%-7s,%zu, "
|
||||
"%.2fx,%.2fx\n",
|
||||
M, K, N, num_B, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(),
|
||||
best.KC(), best.NC(), StringFromOrder(best.Order()),
|
||||
best.InnerTasks(), ratio(tuner.WorstMinTicks()),
|
||||
|
|
@ -906,8 +908,8 @@ class MMLoops {
|
|||
const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT);
|
||||
HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
|
||||
HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
|
||||
const IndexRange& range_mc = args.ranges_mc.Range(0);
|
||||
const IndexRange& range_kc = args.ranges_kc.Range(0);
|
||||
const IndexRange& range_mc = args.ranges_mc.Range(0); // whole M
|
||||
const IndexRange& range_kc = args.ranges_kc.Range(0); // whole K
|
||||
|
||||
parallel.ForN(
|
||||
args.env.ctx, args.range_n, MultipleN(sizeof(TC), args.line_bytes),
|
||||
|
|
@ -941,7 +943,7 @@ class MMLoops {
|
|||
const MMArgs& args) {
|
||||
const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT_K);
|
||||
HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
|
||||
const IndexRange& range_mc = args.ranges_mc.Range(0);
|
||||
const IndexRange& range_mc = args.ranges_mc.Range(0); // whole M
|
||||
|
||||
parallel.ForN(args.env.ctx, args.range_n,
|
||||
MultipleN(sizeof(TC), args.line_bytes), args.inner_tasks,
|
||||
|
|
@ -977,7 +979,7 @@ class MMLoops {
|
|||
const MMArgs& args) {
|
||||
const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT_MT);
|
||||
HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
|
||||
const IndexRange& range_kc = args.ranges_kc.Range(0);
|
||||
const IndexRange& range_kc = args.ranges_kc.Range(0); // whole K
|
||||
|
||||
parallel.ForRangesMC_NC(
|
||||
args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx,
|
||||
|
|
@ -1158,8 +1160,9 @@ HWY_NOINLINE MMPerKey* TwoMatMul(const MatPtrT<BF16>& A, const MatPtrT<TB>& B1,
|
|||
HWY_ASSERT(K <= MMEntireA::kMaxK);
|
||||
HWY_ASSERT(N % kNR == 0);
|
||||
MMImpl::EnsureAligned(A, cache.VectorBytes());
|
||||
tuner.SetCandidates(
|
||||
MMCandidates(cache, M, K, N, num_B, sizeof(BF16), env.print_config));
|
||||
const size_t max_M = MMKeys::BucketM(M);
|
||||
tuner.SetCandidates(MMCandidates(cache, max_M, K, N, num_B, sizeof(BF16),
|
||||
env.print_config));
|
||||
}
|
||||
|
||||
const MMConfig& cfg = tuner.NextConfig();
|
||||
|
|
|
|||
167
ops/matmul.cc
167
ops/matmul.cc
|
|
@ -21,6 +21,7 @@
|
|||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "util/allocator.h"
|
||||
|
|
@ -46,7 +47,9 @@ size_t RoundDownWithFloor(size_t value, size_t multiple) {
|
|||
// multiple of `multiple`, or 0 if none exists.
|
||||
size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
|
||||
const size_t multiple) {
|
||||
HWY_DASSERT(end != 0 && dim != 0 && multiple != 0);
|
||||
HWY_DASSERT(end != 0);
|
||||
HWY_DASSERT(dim != 0);
|
||||
HWY_DASSERT(multiple != 0);
|
||||
size_t prev = RoundDownWithFloor(end, multiple);
|
||||
// Avoid returning `end` if rounding down had no effect.
|
||||
if (prev == end) prev -= multiple;
|
||||
|
|
@ -62,10 +65,10 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
|
|||
// and holds most of their arguments in member variables.
|
||||
class GenerateCandidates {
|
||||
public:
|
||||
GenerateCandidates(const CacheInfo& cache, size_t M, size_t K, size_t N,
|
||||
GenerateCandidates(const CacheInfo& cache, size_t max_M, size_t K, size_t N,
|
||||
size_t num_B, size_t sizeof_TC, bool print_config)
|
||||
: cache_(cache),
|
||||
M_(M),
|
||||
max_M_(max_M),
|
||||
K_(K),
|
||||
N_(N),
|
||||
num_B_(num_B),
|
||||
|
|
@ -89,14 +92,14 @@ class GenerateCandidates {
|
|||
for (size_t mc : MC(mr, kc, order)) {
|
||||
for (size_t nc : NC(mr, mc, kc, order)) {
|
||||
for (int inner_tasks : all_inner_tasks) {
|
||||
const MMConfig config(K_, N_, mr, mc, kc, nc, kc_multiple_,
|
||||
nc_multiple_, order, inner_tasks);
|
||||
const size_t M_tasks = config.RangesOfMC(M_).NumTasks();
|
||||
const MMConfig config(max_M_, K_, N_, mr, mc, kc, nc,
|
||||
kc_multiple_, nc_multiple_, order,
|
||||
inner_tasks);
|
||||
const size_t M_tasks = config.RangesOfMC(max_M_).NumTasks();
|
||||
const size_t K_tasks = config.RangesOfKC(K_).NumTasks();
|
||||
|
||||
// Blocks only make sense when there are multiple M tasks.
|
||||
if (IsBlock(order) != (M_tasks > 1)) continue;
|
||||
// Single KC only makes sense when there is a single K task.
|
||||
// Do not use single-MC/KC order if there are multiple.
|
||||
if (IsOneMC(order) != (M_tasks == 1)) continue;
|
||||
if (IsOneKC(order) != (K_tasks == 1)) continue;
|
||||
|
||||
candidates.push_back(config);
|
||||
|
|
@ -114,6 +117,25 @@ class GenerateCandidates {
|
|||
private:
|
||||
using SizeVec = std::vector<size_t>;
|
||||
|
||||
// Concatenate and print once because this can be called concurrently.
|
||||
void MaybePrintSizes(size_t dim, size_t max, const char* caption,
|
||||
const SizeVec& sizes) const {
|
||||
if (!print_config_ || sizes.empty()) return;
|
||||
std::string out("num_B ");
|
||||
out += std::to_string(num_B_);
|
||||
out += " (";
|
||||
out += std::to_string(dim);
|
||||
out += ", max ";
|
||||
out += std::to_string(max);
|
||||
out += ") ";
|
||||
out += caption;
|
||||
out += ": ";
|
||||
for (size_t size : sizes) {
|
||||
out += std::to_string(size) + " ";
|
||||
}
|
||||
fprintf(stderr, "%s\n", out.c_str());
|
||||
}
|
||||
|
||||
// How many rows of A per call to `MMKernel::LoopKC`. Lower values may
|
||||
// be better for SIMD targets with fewer registers.
|
||||
SizeVec MR() const {
|
||||
|
|
@ -125,14 +147,14 @@ class GenerateCandidates {
|
|||
SizeVec all_mr;
|
||||
all_mr.reserve(3);
|
||||
// AVX2's 16 registers are not enough for four rows, but SSE4 may benefit.
|
||||
if (M_ >= kMaxMR && !is_avx2) all_mr.push_back(kMaxMR);
|
||||
if (max_M_ >= kMaxMR && !is_avx2) all_mr.push_back(kMaxMR);
|
||||
// Allow for AVX-512 but not SSE4 (for which 4 are usually better). Also
|
||||
// enable if not enough rows for 4.
|
||||
if (M_ >= 2 && (M_ < kMaxMR || (!is_sse && !is_wasm))) {
|
||||
if (max_M_ >= 2 && (max_M_ < kMaxMR || (!is_sse && !is_wasm))) {
|
||||
all_mr.push_back(size_t{2});
|
||||
}
|
||||
// Even SSE4 usually prefers 2 rows; only enable for single rows.
|
||||
if (M_ == 1) all_mr.push_back(size_t{1});
|
||||
if (max_M_ == 1) all_mr.push_back(size_t{1});
|
||||
HWY_ASSERT(!all_mr.empty());
|
||||
return all_mr;
|
||||
}
|
||||
|
|
@ -143,18 +165,26 @@ class GenerateCandidates {
|
|||
for (size_t order_idx = 0;; ++order_idx) {
|
||||
const MMOrder order = static_cast<MMOrder>(order_idx);
|
||||
if (StringFromOrder(order) == nullptr) return orders; // done
|
||||
// 2D blocking is useless for a single row of M.
|
||||
if (IsBlock(order) && M_ <= mr) continue;
|
||||
// Multiple-MC is useless for a single row of M.
|
||||
if (!IsOneMC(order) && max_M_ <= mr) continue;
|
||||
// Conversely, N-only parallelism is uncompetitive for large M.
|
||||
if (!IsBlock(order) && M_ >= kMaxTilesM * mr) continue;
|
||||
if (IsOneMC(order) && max_M_ >= 8 * mr) continue;
|
||||
orders.push_back(order);
|
||||
}
|
||||
}
|
||||
|
||||
// The number of A and B columns to read between updating `C`.
|
||||
SizeVec KC(size_t mr, MMOrder order) const {
|
||||
if (IsOneKC(order)) {
|
||||
// A single KC range is infeasible when K exceeds the max. The caller
|
||||
// will skip all configs with `order`.
|
||||
if (K_ > kMaxKC) return SizeVec();
|
||||
// Must return the actual value: although ignored by `RangesOfKC`, this
|
||||
// will be used in MC() and NC().
|
||||
return SizeVec(1, K_);
|
||||
}
|
||||
// `LoopKC` handles up to `mr` rows of A.
|
||||
const size_t rows_a = HWY_MIN(M_, mr);
|
||||
const size_t rows_a = HWY_MIN(max_M_, mr);
|
||||
|
||||
// After looping over `kc` columns, we write `mr x 4` outputs and 16 vector
|
||||
// `buf`. To amortize the write cost, we want to maximize `kc`. However, it
|
||||
|
|
@ -186,7 +216,7 @@ class GenerateCandidates {
|
|||
|
||||
// If we can afford a single K task, that's usually best; only try one
|
||||
// more. Otherwise, blocks may require smaller kc (more options).
|
||||
const size_t reps = (kc_max == K_) ? 1 : IsBlock(order) ? 3 : 2;
|
||||
const size_t reps = (kc_max == K_) ? 1 : IsOneMC(order) ? 2 : 3;
|
||||
|
||||
size_t prev = kc_max;
|
||||
for (size_t rep = 0; rep < reps; ++rep) {
|
||||
|
|
@ -196,22 +226,27 @@ class GenerateCandidates {
|
|||
}
|
||||
}
|
||||
|
||||
if (print_config_ && all_kc.size() > 1) {
|
||||
fprintf(stderr, "num_B %zu: KC: ", num_B_);
|
||||
for (size_t kc : all_kc) {
|
||||
fprintf(stderr, "%zu ", kc);
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
MaybePrintSizes(K_, kc_max, "KC", all_kc);
|
||||
return all_kc;
|
||||
}
|
||||
|
||||
// The number of (L2 resident) A rows for `A2C0` to loop over.
|
||||
SizeVec MC(size_t mr, size_t kc, MMOrder order) const {
|
||||
if (max_M_ <= mr) return SizeVec(1, max_M_);
|
||||
if (IsOneMC(order)) {
|
||||
// A single MC range is infeasible when M exceeds the max. The caller
|
||||
// will skip all configs with `order`.
|
||||
if (max_M_ > kMaxMC) return SizeVec();
|
||||
// Must return the actual value: although ignored by `RangesOfMC`, this
|
||||
// will be used in NC().
|
||||
return SizeVec(1, max_M_);
|
||||
}
|
||||
|
||||
// Typically 12-24K. The B rows are pinned in L1, but also occupy L2 because
|
||||
// it is typically inclusive.
|
||||
const size_t bytes_b = kNR * kc * (sizeof(SfpStream) + sizeof(BF16));
|
||||
// `kc` was chosen to fit in L1, hence this should not exceed L2.
|
||||
HWY_ASSERT(bytes_b <= cache_.L2Bytes());
|
||||
|
||||
// Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the
|
||||
// packed B. We want `mc * kc` elements of A to fit in L2, alongside
|
||||
|
|
@ -219,35 +254,45 @@ class GenerateCandidates {
|
|||
const size_t bytes_per_mc = kc * sizeof(BF16) + cache_.LineBytes();
|
||||
size_t mc_max = hwy::DivCeil(cache_.L2Bytes() - bytes_b, bytes_per_mc);
|
||||
mc_max = HWY_MIN(mc_max, HWY_MIN(kMaxBatchSize, kMaxMC));
|
||||
HWY_DASSERT(mc_max != 0);
|
||||
mc_max = HWY_MIN(mc_max, M_);
|
||||
mc_max = hwy::RoundDownTo(mc_max, mr);
|
||||
mc_max = HWY_MIN(mc_max, max_M_);
|
||||
HWY_ASSERT(mc_max != 0);
|
||||
|
||||
SizeVec all_mc(1, mc_max);
|
||||
// Larger MC is better for non-blocks, otherwise we want more small options,
|
||||
// especially for two B.
|
||||
const size_t reps = !IsBlock(order) ? 2 : (2 + num_B_);
|
||||
SizeVec all_mc;
|
||||
all_mc.reserve(6);
|
||||
|
||||
size_t prev = mc_max;
|
||||
for (size_t rep = 0; rep < reps; ++rep) {
|
||||
prev = PrevDivisor(1, prev, M_, mr);
|
||||
if (prev >= mc_max || prev == 0) break;
|
||||
const size_t rounded_M = HWY_MAX(mr, hwy::RoundDownTo(max_M_, mr));
|
||||
size_t prev = hwy::RoundDownTo(mc_max, mr);
|
||||
|
||||
// If mc_max is large enough, allow using the whole range without rounding
|
||||
// down (which may require two ranges).
|
||||
if (mc_max == max_M_ && (max_M_ % mr) != 0) {
|
||||
all_mc.push_back(max_M_);
|
||||
// The next option should be considerably smaller than `max_M_`.
|
||||
prev = HWY_MAX(mr, hwy::RoundDownTo(3 * prev / 4, mr));
|
||||
} else {
|
||||
all_mc.push_back(prev);
|
||||
}
|
||||
|
||||
// Blocks: largest is not useful.
|
||||
if (IsBlock(order) && all_mc.size() > 1) {
|
||||
all_mc.erase(all_mc.begin(), all_mc.begin() + 1);
|
||||
// We know `order` is multiple MC, where more/smaller values of `mc` are
|
||||
// helpful, especially for two B, hence add iterations.
|
||||
const size_t reps = 2 + num_B_;
|
||||
for (size_t rep = 0; rep < reps; ++rep) {
|
||||
prev = PrevDivisor(mr, prev, rounded_M, mr);
|
||||
if (prev == 0) break; // none found
|
||||
if (prev == mr) {
|
||||
if (all_mc.back() != prev) all_mc.push_back(prev);
|
||||
break;
|
||||
}
|
||||
if (prev <= mc_max / 8) break;
|
||||
all_mc.push_back(prev);
|
||||
}
|
||||
|
||||
if (print_config_ && all_mc.size() > 1) {
|
||||
fprintf(stderr, "num_B %zu: MC: ", num_B_);
|
||||
for (size_t mc : all_mc) {
|
||||
fprintf(stderr, "%zu ", mc);
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
if (all_mc.size() <= 2) {
|
||||
if (max_M_ > mr) all_mc.push_back(max_M_ / 2);
|
||||
if (mc_max > mr) all_mc.push_back(mc_max / 2);
|
||||
}
|
||||
|
||||
MaybePrintSizes(max_M_, mc_max, "MC", all_mc);
|
||||
return all_mc;
|
||||
}
|
||||
|
||||
|
|
@ -257,7 +302,7 @@ class GenerateCandidates {
|
|||
// Only if there will be reuse of B: choose the largest `nc_max` (C cols)
|
||||
// such that `nc x kc` of B and `mc x nc` of `C` fit in L3. Otherwise,
|
||||
// leave it unbounded.
|
||||
if (M_ > mr) {
|
||||
if (max_M_ > mr) {
|
||||
const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * sizeof_TC_);
|
||||
nc_max = HWY_MIN(hwy::DivCeil(cache_.L3Bytes(), bytes_per_nc), kMaxNC);
|
||||
}
|
||||
|
|
@ -271,8 +316,8 @@ class GenerateCandidates {
|
|||
nc_max = RoundDownWithFloor(N_ / 2, nc_multiple_);
|
||||
}
|
||||
|
||||
// Non-block calls ForNP, which ignores `range_nc` and uses `range_np`.
|
||||
if (!IsBlock(order)) return SizeVec(1, N_);
|
||||
// Single-MC calls `ForNP`, which ignores `range_nc`.
|
||||
if (IsOneMC(order)) return SizeVec(1, N_);
|
||||
|
||||
SizeVec all_nc(1, nc_max);
|
||||
|
||||
|
|
@ -282,7 +327,7 @@ class GenerateCandidates {
|
|||
// hence autotune a wider range of nc than the other dimensions.
|
||||
size_t reps = 9 + num_B_;
|
||||
// For small M, we can afford larger NC, hence allow fewer small options.
|
||||
if (M_ <= 2 * mr) reps -= 1;
|
||||
if (max_M_ <= 2 * mr) reps -= 1;
|
||||
|
||||
size_t prev = nc_max;
|
||||
for (size_t rep = 0; rep < reps; ++rep) {
|
||||
|
|
@ -302,14 +347,7 @@ class GenerateCandidates {
|
|||
all_nc.begin() + HWY_MIN(want_delete, max_delete));
|
||||
}
|
||||
|
||||
if (print_config_ && all_nc.size() > 1) {
|
||||
fprintf(stderr, "num_B %zu: NC: ", num_B_);
|
||||
for (size_t nc : all_nc) {
|
||||
fprintf(stderr, "%zu ", nc);
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
MaybePrintSizes(N_, nc_max, "NC", all_nc);
|
||||
return all_nc;
|
||||
}
|
||||
|
||||
|
|
@ -319,8 +357,8 @@ class GenerateCandidates {
|
|||
std::vector<int> inner_tasks;
|
||||
inner_tasks.reserve(3);
|
||||
inner_tasks.push_back(1);
|
||||
// Blocks have one task per mc/nc range and ignore this parameter.
|
||||
if (!IsBlock(order)) {
|
||||
// Multiple-MC have one task per mc/nc range and ignore this parameter.
|
||||
if (IsOneMC(order)) {
|
||||
inner_tasks.push_back(2);
|
||||
inner_tasks.push_back(4);
|
||||
}
|
||||
|
|
@ -328,7 +366,7 @@ class GenerateCandidates {
|
|||
}
|
||||
|
||||
const CacheInfo& cache_;
|
||||
const size_t M_;
|
||||
const size_t max_M_;
|
||||
const size_t K_;
|
||||
const size_t N_;
|
||||
const size_t num_B_;
|
||||
|
|
@ -343,10 +381,11 @@ class GenerateCandidates {
|
|||
} // namespace
|
||||
|
||||
// Facade to avoid exposing `GenerateCandidates` in the header.
|
||||
std::vector<MMConfig> MMCandidates(const CacheInfo& cache, size_t M, size_t K,
|
||||
size_t N, size_t num_B, size_t sizeof_TC,
|
||||
bool print_config) {
|
||||
return GenerateCandidates(cache, M, K, N, num_B, sizeof_TC, print_config)();
|
||||
std::vector<MMConfig> MMCandidates(const CacheInfo& cache, size_t max_M,
|
||||
size_t K, size_t N, size_t num_B,
|
||||
size_t sizeof_TC, bool print_config) {
|
||||
return GenerateCandidates(cache, max_M, K, N, num_B, sizeof_TC,
|
||||
print_config)();
|
||||
}
|
||||
|
||||
MatMulEnv::MatMulEnv(ThreadingContext& ctx)
|
||||
|
|
|
|||
233
ops/matmul.h
233
ops/matmul.h
|
|
@ -61,7 +61,7 @@ HWY_INLINE_VAR constexpr size_t kMaxNC = 16384;
|
|||
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
|
||||
HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024;
|
||||
|
||||
// Policy classes for parallelism, implementing some of `ParallelismStrategy`.
|
||||
// Policy classes for parallelism, implementing some of `Parallelism`.
|
||||
|
||||
struct MMParallelNone {
|
||||
template <class Func>
|
||||
|
|
@ -103,17 +103,13 @@ struct MMParallelWithinCluster {
|
|||
template <class Func>
|
||||
void ForN(ThreadingContext& ctx, const IndexRange& range_n, size_t n_multiple,
|
||||
size_t inner_tasks, size_t cluster_idx, const Func& func) const {
|
||||
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
|
||||
const hwy::pool::Caller caller =
|
||||
ctx.pool_callers.Get(Callers::kMMClusterForN);
|
||||
|
||||
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
|
||||
const size_t base = ctx.Worker(cluster_idx);
|
||||
|
||||
const IndexRangePartition ranges_n = StaticPartition(
|
||||
range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
|
||||
ParallelizeOneRange(ranges_n, cluster,
|
||||
ctx.pool_callers.Get(Callers::kMMClusterForN),
|
||||
ParallelPartitionWithinCluster(
|
||||
range_n, n_multiple, inner_tasks, ctx, cluster_idx, caller,
|
||||
[&](const IndexRange& worker_range, size_t worker) {
|
||||
func(worker_range, base + worker);
|
||||
func(worker_range, worker);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -122,80 +118,57 @@ struct MMParallelWithinCluster {
|
|||
const IndexRangePartition& ranges_mc,
|
||||
const IndexRangePartition& ranges_nc, size_t cluster_idx,
|
||||
const Func& func) const {
|
||||
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
|
||||
const size_t base = ctx.Worker(cluster_idx);
|
||||
const hwy::pool::Caller caller =
|
||||
ctx.pool_callers.Get(Callers::kMMClusterForMCNC);
|
||||
|
||||
// Low-batch: avoid Divide/Remainder.
|
||||
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
|
||||
ParallelizeOneRange(ranges_nc, cluster,
|
||||
ctx.pool_callers.Get(Callers::kMMClusterForMCNC),
|
||||
[&](const IndexRange& range_nc, size_t worker) {
|
||||
func(ranges_mc.Range(0), range_nc, base + worker);
|
||||
// We are running on one pool, hence collapse into a 1D range.
|
||||
const hwy::Divisor div_m(static_cast<uint32_t>(ranges_mc.NumTasks()));
|
||||
const auto get_mc = [&](uint64_t task) {
|
||||
return ranges_mc.Range(div_m.Remainder(static_cast<uint32_t>(task)));
|
||||
};
|
||||
const auto get_nc = [&](uint64_t task) {
|
||||
return ranges_nc.Range(div_m.Divide(static_cast<uint32_t>(task)));
|
||||
};
|
||||
const size_t num_tasks = ranges_mc.NumTasks() * ranges_nc.NumTasks();
|
||||
|
||||
ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller,
|
||||
[&](uint64_t task, size_t worker) {
|
||||
func(get_mc(task), get_nc(task), worker);
|
||||
});
|
||||
} else {
|
||||
ParallelizeTwoRanges(
|
||||
ranges_mc, ranges_nc, cluster,
|
||||
ctx.pool_callers.Get(Callers::kMMClusterForMCNC),
|
||||
[&](const IndexRange& range_mc, const IndexRange& range_nc,
|
||||
size_t worker) { func(range_mc, range_nc, base + worker); });
|
||||
}
|
||||
}
|
||||
|
||||
template <class Func>
|
||||
void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
|
||||
size_t cluster_idx, const Func& func) const {
|
||||
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
|
||||
const size_t base = ctx.Worker(cluster_idx);
|
||||
const hwy::pool::Caller caller =
|
||||
ctx.pool_callers.Get(Callers::kMMClusterForMC);
|
||||
|
||||
cluster.Run(
|
||||
range_mc.begin(), range_mc.end(),
|
||||
ctx.pool_callers.Get(Callers::kMMClusterForMC),
|
||||
[&](uint64_t row_a, size_t worker) { func(row_a, base + worker); });
|
||||
ParallelForWithinCluster(
|
||||
range_mc.Num(), ctx, cluster_idx, caller,
|
||||
[&](uint64_t i, size_t worker) { func(range_mc.begin() + i, worker); });
|
||||
}
|
||||
};
|
||||
|
||||
struct MMParallelHierarchical {
|
||||
// Cluster/CCX-aware parallel-for over B rows in `range_n`. `n_multiple` is
|
||||
// the granularity of per-cluster tasks. Calls `func(worker_range, worker)`.
|
||||
// Similar to `HierarchicalParallelFor`, but over *sub-ranges* of B rows in
|
||||
// `range_n` governed by `n_multiple` and `inner_tasks`.
|
||||
template <class Func>
|
||||
void ForN(ThreadingContext& ctx, const IndexRange& range_n, size_t n_multiple,
|
||||
size_t inner_tasks, HWY_MAYBE_UNUSED size_t caller_cluster_idx,
|
||||
size_t inner_tasks, size_t caller_cluster_idx,
|
||||
const Func& func) const {
|
||||
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
|
||||
HWY_DASSERT(caller_cluster_idx == 0);
|
||||
(void)caller_cluster_idx;
|
||||
const hwy::pool::Caller caller = ctx.pool_callers.Get(Callers::kMMHierForN);
|
||||
|
||||
// Single cluster: parallel-for over static partition of `range_n`.
|
||||
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
|
||||
const size_t num_clusters = all_clusters.NumWorkers();
|
||||
if (num_clusters == 1) {
|
||||
const size_t cluster_idx = 0;
|
||||
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
|
||||
const IndexRangePartition ranges_n = StaticPartition(
|
||||
range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
|
||||
return ParallelizeOneRange(
|
||||
ranges_n, cluster, caller,
|
||||
// Assign clusters (if any) a sub-range of `range_n` (typically hundreds).
|
||||
ParallelPartitionAcrossClusters(
|
||||
range_n, n_multiple, /*inner_tasks=*/1, ctx, caller,
|
||||
[&](const IndexRange& cluster_range, size_t cluster_idx) {
|
||||
ParallelPartitionWithinCluster(
|
||||
cluster_range, n_multiple, inner_tasks, ctx, cluster_idx, caller,
|
||||
[&](const IndexRange& worker_range, size_t worker) {
|
||||
func(worker_range, worker);
|
||||
});
|
||||
}
|
||||
|
||||
// Assign each cluster a sub-range of `range_n` (typically hundreds).
|
||||
const IndexRangePartition ranges_n =
|
||||
StaticPartition(range_n, num_clusters, n_multiple);
|
||||
ParallelizeOneRange(
|
||||
ranges_n, all_clusters, caller,
|
||||
[&](const IndexRange& n_range, const size_t cluster_idx) {
|
||||
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
|
||||
const size_t cluster_base = ctx.Worker(cluster_idx);
|
||||
// Parallel-for over sub-ranges of `cluster_range` within the cluster.
|
||||
const IndexRangePartition worker_ranges = StaticPartition(
|
||||
n_range, cluster.NumWorkers() * inner_tasks, n_multiple);
|
||||
ParallelizeOneRange(
|
||||
worker_ranges, cluster, caller,
|
||||
[&](const IndexRange& worker_range, size_t worker) {
|
||||
func(worker_range, cluster_base + worker);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -205,69 +178,56 @@ struct MMParallelHierarchical {
|
|||
void ForRangesMC_NC(ThreadingContext& ctx,
|
||||
const IndexRangePartition& ranges_mc,
|
||||
const IndexRangePartition& ranges_nc,
|
||||
HWY_MAYBE_UNUSED size_t caller_cluster_idx,
|
||||
const Func& func) const {
|
||||
size_t caller_cluster_idx, const Func& func) const {
|
||||
HWY_DASSERT(caller_cluster_idx == 0);
|
||||
(void)caller_cluster_idx;
|
||||
const hwy::pool::Caller caller =
|
||||
ctx.pool_callers.Get(Callers::kMMHierForMCNC);
|
||||
|
||||
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
|
||||
// `all_clusters` is a pool with one worker per cluster in a package.
|
||||
const size_t num_clusters = all_clusters.NumWorkers();
|
||||
// Single (big) cluster: collapse two range indices into one parallel-for
|
||||
// to reduce the number of fork-joins.
|
||||
if (num_clusters == 1) {
|
||||
const size_t cluster_idx = 0;
|
||||
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
|
||||
// Low-batch: avoid Divide/Remainder.
|
||||
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
|
||||
return ParallelizeOneRange(
|
||||
ranges_nc, cluster, caller,
|
||||
[&](const IndexRange& range_nc, size_t worker) {
|
||||
func(ranges_mc.Range(0), range_nc, worker);
|
||||
});
|
||||
} else {
|
||||
return ParallelizeTwoRanges(
|
||||
ranges_mc, ranges_nc, cluster, caller,
|
||||
[&](const IndexRange& range_mc, const IndexRange& range_nc,
|
||||
size_t worker) { func(range_mc, range_nc, worker); });
|
||||
}
|
||||
}
|
||||
// Collapse two range indices into a 1D range for better load-balancing,
|
||||
// because `ranges_mc` may just have one task.
|
||||
const hwy::Divisor div_m(static_cast<uint32_t>(ranges_mc.NumTasks()));
|
||||
const auto get_mc = [&](uint64_t task) {
|
||||
return ranges_mc.Range(div_m.Remainder(static_cast<uint32_t>(task)));
|
||||
};
|
||||
const auto get_nc = [&](uint64_t task) {
|
||||
return ranges_nc.Range(div_m.Divide(static_cast<uint32_t>(task)));
|
||||
};
|
||||
const IndexRange all_range(0, ranges_mc.NumTasks() * ranges_nc.NumTasks());
|
||||
|
||||
// Multiple clusters: N across clusters (both are usually the larger), and
|
||||
// M within each cluster. We assume auto-tuning finds small MC/NC tasks.
|
||||
ParallelizeOneRange(
|
||||
ranges_nc, all_clusters, caller,
|
||||
[&](const IndexRange range_nc, size_t cluster_idx) {
|
||||
const size_t cluster_base = ctx.Worker(cluster_idx);
|
||||
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
|
||||
ParallelizeOneRange(ranges_mc, cluster, caller,
|
||||
[&](const IndexRange& range_mc, size_t worker) {
|
||||
func(range_mc, range_nc, cluster_base + worker);
|
||||
ParallelPartitionAcrossClusters(
|
||||
all_range, /*task_multiple=*/1, /*inner_tasks=*/1, ctx, caller,
|
||||
[&](const IndexRange& cluster_range, size_t cluster_idx) {
|
||||
ParallelForWithinCluster(cluster_range.Num(), ctx, cluster_idx,
|
||||
caller, [&](uint64_t i, size_t worker) {
|
||||
const size_t task =
|
||||
cluster_range.begin() + i;
|
||||
func(get_mc(task), get_nc(task), worker);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// Calls `func(row_a, worker)` in parallel.
|
||||
// No multiple/inner_tasks, so this is just HierarchicalParallelFor.
|
||||
template <class Func>
|
||||
void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
|
||||
size_t caller_cluster_idx, const Func& func) const {
|
||||
HierarchicalParallelFor(range_mc.Num(), ctx, Callers::kMMHierForMC,
|
||||
[&](size_t task, size_t worker) {
|
||||
func(range_mc.begin() + task, worker);
|
||||
});
|
||||
HWY_DASSERT(caller_cluster_idx == 0);
|
||||
(void)caller_cluster_idx;
|
||||
HierarchicalParallelFor(
|
||||
range_mc.Num(), ctx, Callers::kMMHierForMC,
|
||||
[&](size_t i, size_t worker) { func(range_mc.begin() + i, worker); });
|
||||
}
|
||||
};
|
||||
|
||||
template <class Func, typename... Args>
|
||||
void DispatchParallelism(ParallelismStrategy parallelism, const Func& func,
|
||||
void DispatchParallelism(Parallelism parallelism, const Func& func,
|
||||
Args&&... args) {
|
||||
switch (parallelism) {
|
||||
case ParallelismStrategy::kNone:
|
||||
case Parallelism::kNone:
|
||||
return func(MMParallelNone(), std::forward<Args>(args)...);
|
||||
case ParallelismStrategy::kWithinCluster:
|
||||
case Parallelism::kWithinCluster:
|
||||
return func(MMParallelWithinCluster(), std::forward<Args>(args)...);
|
||||
case ParallelismStrategy::kHierarchical:
|
||||
case Parallelism::kHierarchical:
|
||||
return func(MMParallelHierarchical(), std::forward<Args>(args)...);
|
||||
default:
|
||||
HWY_UNREACHABLE;
|
||||
|
|
@ -371,8 +331,8 @@ void DispatchOrder(MMOrder order, const Func& func, Args&&... args) {
|
|||
}
|
||||
}
|
||||
|
||||
static inline bool IsBlock(MMOrder order) {
|
||||
return order == MMOrder::kNT_MT_K || order == MMOrder::kNT_MT;
|
||||
static inline bool IsOneMC(MMOrder order) {
|
||||
return order == MMOrder::kNT || order == MMOrder::kNT_K;
|
||||
}
|
||||
|
||||
static inline bool IsOneKC(MMOrder order) {
|
||||
|
|
@ -421,6 +381,8 @@ static inline const char* StringFromParA(MMParA par_a) {
|
|||
// `mc` := A rows such that `kc` columns fit in L2,
|
||||
// `nc` := B rows such that `kc` columns fit in L3 alongside `mc x nc` C.
|
||||
// Also includes loop order and task granularity.
|
||||
//
|
||||
// This is shared by multiple M which return the same `BucketM`.
|
||||
#pragma pack(push, 1)
|
||||
class MMConfig {
|
||||
public:
|
||||
|
|
@ -428,8 +390,8 @@ class MMConfig {
|
|||
// `mr` is the number of A rows per call to `MMKernel::LoopKC`.
|
||||
// `MMOrder` is how to parallelize the outer loops.
|
||||
// `inner_tasks` chooses the within-cluster task granularity in `ForN`.
|
||||
MMConfig(size_t K, size_t N, size_t mr, size_t mc, size_t kc, size_t nc,
|
||||
size_t kc_multiple, size_t nc_multiple, MMOrder order,
|
||||
MMConfig(size_t M, size_t K, size_t N, size_t mr, size_t mc, size_t kc,
|
||||
size_t nc, size_t kc_multiple, size_t nc_multiple, MMOrder order,
|
||||
int inner_tasks)
|
||||
: mr_(static_cast<uint32_t>(mr)),
|
||||
mc_(static_cast<uint32_t>(mc)),
|
||||
|
|
@ -441,11 +403,7 @@ class MMConfig {
|
|||
inner_tasks_(static_cast<uint8_t>(inner_tasks)),
|
||||
reserved_{} {
|
||||
HWY_DASSERT(mr == 1 || mr == 2 || mr == 4);
|
||||
if (mc % mr != 0) {
|
||||
HWY_WARN("mc %zu not a multiple of mr %zu", mc, mr);
|
||||
}
|
||||
// Do not warn for single-kc tasks; some models unfortunately have K which
|
||||
// are not multiples of `kc_multiple`.
|
||||
// Some models have K which are not multiples of `kc_multiple`.
|
||||
if (kc != K && (kc % kc_multiple) != 0) {
|
||||
HWY_WARN("kc %zu not a multiple of kc_multiple %zu", kc, kc_multiple);
|
||||
}
|
||||
|
|
@ -457,11 +415,21 @@ class MMConfig {
|
|||
}
|
||||
|
||||
// Splits M/N into blocks which are visited sequentially or in parallel.
|
||||
// K is always sequential, see `MMOrder`.
|
||||
IndexRangePartition RangesOfMC(size_t M) const {
|
||||
return MaxSizePartition(IndexRange(0, M), mc_, mr_);
|
||||
if (IsOneMC(order_)) {
|
||||
// Must have exactly one M range/tile, regardless of `mr_` and `mc_`.
|
||||
return IndexRangePartition(M);
|
||||
}
|
||||
const size_t mc = HWY_MIN(M, MC());
|
||||
const size_t mr = HWY_MIN(M, MR());
|
||||
return MaxSizePartition(IndexRange(0, M), mc, mr);
|
||||
}
|
||||
// K is either a single range, or a sequential loop.
|
||||
IndexRangePartition RangesOfKC(size_t K) const {
|
||||
if (IsOneKC(order_)) {
|
||||
// Must have exactly one K range/tile, regardless of `kc_`.
|
||||
return IndexRangePartition(K);
|
||||
}
|
||||
return MaxSizePartition(IndexRange(0, K), kc_, kc_multiple_);
|
||||
}
|
||||
IndexRangePartition RangesOfNC(size_t N) const {
|
||||
|
|
@ -488,7 +456,7 @@ class MMConfig {
|
|||
uint32_t kc_multiple_;
|
||||
MMOrder order_;
|
||||
uint8_t inner_tasks_;
|
||||
HWY_MAYBE_UNUSED uint8_t reserved_[6];
|
||||
HWY_MEMBER_VAR_MAYBE_UNUSED uint8_t reserved_[6];
|
||||
};
|
||||
static_assert(sizeof(MMConfig) == 32); // for faster indexing
|
||||
#pragma pack(pop)
|
||||
|
|
@ -597,26 +565,27 @@ class MMAutoTune {
|
|||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// Minimum M, in units of tile rows of height mr={1, 2, 4}, from which
|
||||
// `MMOrder::kNT[_K]` are no longer allowed. They require a single MC range,
|
||||
// but choosing the same config for a larger M can result in multiple MC ranges.
|
||||
// Thus M less than this must have unique keys/configs.
|
||||
HWY_INLINE_VAR constexpr size_t kMaxTilesM = 8;
|
||||
|
||||
// Map of previously seen dimensions to index via linear search.
|
||||
class MMKeys {
|
||||
// Group batch size into buckets to reduce #auto-tunes.
|
||||
static size_t BucketM(size_t M) {
|
||||
if (M < kMaxTilesM * kMaxMR) return M; // See kMaxTilesM above.
|
||||
if (M <= 128) return 128;
|
||||
return 512;
|
||||
}
|
||||
|
||||
public:
|
||||
using Key = uint64_t;
|
||||
// KeyFromDims will only return this if all dims are zero, which is invalid.
|
||||
static constexpr Key kPadding = 0;
|
||||
|
||||
// Returns the maximum permissible M in the bucket, for grouping batch sizes
|
||||
// into buckets to reduce #auto-tunes.
|
||||
static size_t BucketM(size_t M) {
|
||||
HWY_DASSERT(M != 0);
|
||||
// Small M: 1..3, 4..7, 8..15, etc. share the same config.
|
||||
if (M < 64) return M | (kMaxMR - 1);
|
||||
// Larger M use power of two buckets: 64..127, 128..255, etc.
|
||||
const size_t floor_log2_M =
|
||||
31 - hwy::Num0BitsAboveMS1Bit_Nonzero32(static_cast<uint32_t>(M));
|
||||
const size_t min_M = size_t{1} << floor_log2_M;
|
||||
HWY_DASSERT(min_M <= M && M < 2 * min_M);
|
||||
return 2 * min_M - 1;
|
||||
}
|
||||
|
||||
// Compresses the dimensions into a single Key for faster comparison.
|
||||
static Key KeyFromDims(size_t M, size_t K, size_t N, size_t num_B) {
|
||||
HWY_DASSERT(M < (Key{1} << 16)); // batch sizes are smaller
|
||||
|
|
@ -747,7 +716,7 @@ class MMOptions {
|
|||
const void* opaque = nullptr;
|
||||
|
||||
uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`.
|
||||
ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical;
|
||||
Parallelism parallelism = Parallelism::kHierarchical;
|
||||
};
|
||||
|
||||
// Arguments to MatMul() that are independent of the A/B/C types. Reduces
|
||||
|
|
|
|||
|
|
@ -195,9 +195,10 @@ HWY_INLINE void MatMulSlow(const MatPtrT<TA> A, const MatPtrT<TB> B,
|
|||
const size_t multiple = env.ctx.allocator.QuantumBytes() / sizeof(TB);
|
||||
const IndexRangePartition get_col_c =
|
||||
StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple);
|
||||
ParallelizeOneRange(
|
||||
get_col_c, all_clusters, env.ctx.pool_callers.Get(Callers::kTest),
|
||||
[&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR {
|
||||
ParallelForAcrossClusters(
|
||||
get_col_c.NumTasks(), env.ctx, env.ctx.pool_callers.Get(Callers::kTest),
|
||||
[&](size_t range_idx, size_t cluster_idx) HWY_ATTR {
|
||||
const IndexRange cols_c = get_col_c.Range(range_idx);
|
||||
for (size_t r : all_rows_c) {
|
||||
TC* HWY_RESTRICT C_row = C.Row(r);
|
||||
for (size_t c : cols_c) {
|
||||
|
|
|
|||
|
|
@ -25,9 +25,11 @@
|
|||
#include <cstdint>
|
||||
#include <random>
|
||||
#include <type_traits> // std::enable_if_t
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "ops/matmul.h"
|
||||
#include "ops/ops.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/basics.h" // TokenAndProb, RngStream
|
||||
#include "util/mat.h"
|
||||
|
|
@ -61,6 +63,9 @@ namespace gcpp {
|
|||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
// Computes C = A * B + add via MatMulStatic.
|
||||
// This function uses CallUpcasted to dispatch to the correct MatMulStatic
|
||||
// instantiation based on the runtime type of B.
|
||||
template <typename TA, typename TC>
|
||||
MMPerKey* CallMatMul(const MatPtrT<TA>& A, const MatPtr& B,
|
||||
const float* HWY_RESTRICT add, MatMulEnv& env,
|
||||
|
|
@ -497,10 +502,10 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
|
|||
size_t cluster_idx = 0) {
|
||||
HWY_DASSERT(weights.Rows() == 1);
|
||||
HWY_DASSERT(weights.Cols() == activations.Cols());
|
||||
HWY_DASSERT(activations.SameShape(out));
|
||||
activations.DebugCheckSameShape(out);
|
||||
|
||||
CallUpcasted(&weights, [&](const auto* weights_t) {
|
||||
ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx,
|
||||
ParallelFor(Parallelism::kFlat, activations.Rows(), ctx,
|
||||
cluster_idx, Callers::kOpsRMSNormBatched,
|
||||
[&](uint64_t token_idx, size_t worker) {
|
||||
RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(),
|
||||
|
|
@ -517,7 +522,7 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
|
|||
HWY_DASSERT(weights.Cols() == inout.Cols());
|
||||
|
||||
CallUpcasted(&weights, [&](const auto* weights_t) {
|
||||
ParallelFor(ParallelismStrategy::kFlat, inout.Rows(), ctx, cluster_idx,
|
||||
ParallelFor(Parallelism::kFlat, inout.Rows(), ctx, cluster_idx,
|
||||
Callers::kOpsRMSNormInplaceBatched,
|
||||
[&](uint64_t token_idx, size_t worker) {
|
||||
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0,
|
||||
|
|
@ -550,7 +555,7 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
|
|||
size_t cluster_idx = 0) {
|
||||
HWY_DASSERT(out.SameShape(x));
|
||||
ParallelFor(
|
||||
ParallelismStrategy::kFlat, out.Rows(), ctx, cluster_idx,
|
||||
Parallelism::kFlat, out.Rows(), ctx, cluster_idx,
|
||||
Callers::kOpsAddFromBatched, [&](uint64_t token_idx, size_t worker) {
|
||||
AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), ctx, worker);
|
||||
});
|
||||
|
|
@ -1122,9 +1127,25 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
|
|||
|
||||
// See below for a specialized version for top-1 sampling.
|
||||
// TODO: support bf16 logits using Decompress2.
|
||||
// Computes softmax probabilities for the given logits, normalizing in-place.
|
||||
// The calculation is numerically stable, using the max-subtraction trick to
|
||||
// compute exp(logits[i] - max(logits)) before normalizing by the sum.
|
||||
// If temperature is provided and not 1.0, each intermediate exp() result is
|
||||
// divided by temperature before normalization; however, this division by
|
||||
// temperature cancels out during the final normalization step, meaning
|
||||
// temperature currently has no effect on the output probabilities.
|
||||
// @param logits In-out: on input, contains logits; on output, overwritten with
|
||||
// probabilities.
|
||||
// @param ctx Input: threading context for parallelism and profiling.
|
||||
// @param worker Input: worker thread index.
|
||||
// @param temperature Input: softmax temperature.
|
||||
// @param softmax_max_out Optional output: if not null, stores the max logit
|
||||
// value.
|
||||
// @param softmax_d_out Optional output: if softmax_max is not null, this must
|
||||
// not be null and stores the sum of exp(logit - max).
|
||||
static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
|
||||
const size_t worker,
|
||||
float temperature = 1.0f) {
|
||||
const size_t worker, float temperature = 1.0f,
|
||||
const SMOptions& sm_options = {}) {
|
||||
GCPP_ZONE(ctx, worker, Zones::kOpsSoftmax);
|
||||
HWY_DASSERT(logits.size() != 0);
|
||||
|
||||
|
|
@ -1168,6 +1189,10 @@ static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
|
|||
// Double-precision reciprocal does not appear to affect the results.
|
||||
const float mul = 1.0f / sum_exp;
|
||||
MulByConst(mul, logits.data(), logits.size());
|
||||
if (sm_options.max_out) {
|
||||
*sm_options.max_out = hn::GetLane(vmax);
|
||||
*sm_options.d_out = sum_exp;
|
||||
}
|
||||
}
|
||||
|
||||
// Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
|
||||
|
|
@ -1290,7 +1315,7 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
|
|||
const float cap, MatPtrT<float>& x, const hwy::BitSet4096<>& non_eos,
|
||||
ThreadingContext& ctx, size_t cluster_idx = 0) {
|
||||
if (cap == 0.0f) return;
|
||||
ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx,
|
||||
ParallelFor(Parallelism::kFlat, x.Rows(), ctx, cluster_idx,
|
||||
Callers::kOpsMaybeLogitsSoftCapBatched,
|
||||
[&](uint64_t task, size_t worker) {
|
||||
if (non_eos.Get(task)) {
|
||||
|
|
|
|||
|
|
@ -41,6 +41,11 @@ static inline HWY_MAYBE_UNUSED MatStorageT<float> CreateInvTimescale(
|
|||
return inv_timescale;
|
||||
}
|
||||
|
||||
struct SMOptions {
|
||||
float* HWY_RESTRICT max_out = nullptr;
|
||||
float* HWY_RESTRICT d_out = nullptr;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_OPS_OPS_H_
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include "compression/types.h"
|
||||
#include "util/zones.h"
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||
#endif // HWY_DISABLED_TARGETS
|
||||
|
|
@ -38,7 +37,6 @@
|
|||
#include "util/mat.h" // MatStorageT
|
||||
#include "util/test_util.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
|
||||
// clang-format off
|
||||
|
|
@ -348,6 +346,51 @@ void TestAllSoftmax() {
|
|||
hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmax>>()(float());
|
||||
}
|
||||
|
||||
class TestSoftmaxState {
|
||||
public:
|
||||
template <class D>
|
||||
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
||||
hwy::RandomState& rng) {
|
||||
if (count == 0) return; // *Softmax would assert
|
||||
if (misalign_b == 0) return;
|
||||
using T = hn::TFromD<D>;
|
||||
|
||||
hwy::AlignedFreeUniquePtr<T[]> px =
|
||||
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
||||
hwy::AlignedFreeUniquePtr<T[]> pe =
|
||||
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
||||
HWY_ASSERT(px && pe);
|
||||
|
||||
T* x = px.get() + misalign_a;
|
||||
T* initial_logits = pe.get() + misalign_a;
|
||||
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
x[i] = Random<T>(rng);
|
||||
initial_logits[i] = x[i];
|
||||
}
|
||||
|
||||
float softmax_max;
|
||||
float softmax_d;
|
||||
Softmax(Logits(x, count), Ctx(), /*worker=*/0, /*temperature=*/1.0f,
|
||||
{.max_out = &softmax_max, .d_out = &softmax_d});
|
||||
|
||||
const float maxval =
|
||||
*std::max_element(initial_logits, initial_logits + count);
|
||||
|
||||
float sum_exp = 0.0f;
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
sum_exp += std::exp(initial_logits[i] - maxval);
|
||||
}
|
||||
|
||||
ASSERT_NEAR(softmax_max, maxval, 1e-6);
|
||||
ASSERT_NEAR(softmax_d, sum_exp, 1e-6);
|
||||
}
|
||||
};
|
||||
|
||||
void TestAllSoftmaxState() {
|
||||
hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmaxState>>()(float());
|
||||
}
|
||||
|
||||
template <size_t k>
|
||||
struct TestCreateDistribution {
|
||||
void operator()(hwy::RandomState& rng) {
|
||||
|
|
@ -456,7 +499,7 @@ void TestRopeAndMulBy() {
|
|||
x.Row(0)[i] = random_float();
|
||||
}
|
||||
|
||||
const float qmul = AttentionActivations::ChooseQueryScale(config);
|
||||
const float qmul = ChooseQueryScale(config);
|
||||
constexpr float kmul = 1.0f;
|
||||
|
||||
MatStorageT<float> qexpected("qexpected", dim_qkv, ctx.allocator);
|
||||
|
|
@ -771,6 +814,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
|
|||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstTo);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmaxState);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSigmoid);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllGelu);
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ cc_test(
|
|||
deps = [
|
||||
":image",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -37,8 +37,6 @@
|
|||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
// Hardcoded for PaliGemma ViT input.
|
||||
constexpr size_t kPatchSize = 14;
|
||||
|
||||
// Returns the linearly scaled index in [0, to_size) closest to the
|
||||
// value in [0, from_size).
|
||||
|
|
@ -208,24 +206,25 @@ bool Image::WriteBinary(const std::string& filename) const {
|
|||
}
|
||||
|
||||
// Image.data() is H x W x 3.
|
||||
// We want the N-th patch of size kPatchSize x kPatchSize x 3.
|
||||
void Image::GetPatch(size_t patch_num, float* patch) const {
|
||||
// We want the N-th patch of size patch_dim x patch_dim x 3.
|
||||
void Image::GetPatch(size_t patch_num, const hwy::Divisor& div_patch_dim,
|
||||
float* patch) const {
|
||||
PROFILER_FUNC;
|
||||
constexpr size_t kNumChannels = 3;
|
||||
constexpr size_t kBytesPerPixel = (kNumChannels * sizeof(float));
|
||||
constexpr size_t kBytesPerRow = (kPatchSize * kBytesPerPixel);
|
||||
const size_t kDataSize = width_ * height_ * kNumChannels;
|
||||
constexpr size_t kBytesPerPixel = kNumChannels * sizeof(float);
|
||||
const size_t patch_dim = div_patch_dim.GetDivisor();
|
||||
const size_t bytes_per_row = (patch_dim * kBytesPerPixel);
|
||||
const size_t in_bytes_to_next_row = (width_ * kBytesPerPixel);
|
||||
HWY_ASSERT(size() == kDataSize);
|
||||
HWY_ASSERT(width_ % kPatchSize == 0);
|
||||
HWY_ASSERT(height_ % kPatchSize == 0);
|
||||
const size_t kNumPatchesPerRow = width_ / kPatchSize;
|
||||
size_t patch_y = patch_num / kNumPatchesPerRow;
|
||||
size_t patch_x = patch_num % kNumPatchesPerRow;
|
||||
HWY_ASSERT(0 <= patch_y && patch_y < height_ / kPatchSize);
|
||||
HWY_ASSERT(0 <= patch_x && patch_x < kNumPatchesPerRow);
|
||||
patch_y *= kPatchSize;
|
||||
patch_x *= kPatchSize;
|
||||
HWY_ASSERT(size() == width_ * height_ * kNumChannels);
|
||||
HWY_ASSERT(div_patch_dim.Remainder(width_) == 0);
|
||||
HWY_ASSERT(div_patch_dim.Remainder(height_) == 0);
|
||||
const size_t patches_x = div_patch_dim.Divide(width_);
|
||||
size_t patch_y = patch_num / patches_x;
|
||||
size_t patch_x = patch_num % patches_x;
|
||||
HWY_DASSERT(0 <= patch_y && patch_y < div_patch_dim.Divide(height_));
|
||||
HWY_DASSERT(0 <= patch_x && patch_x < patches_x);
|
||||
patch_y *= patch_dim;
|
||||
patch_x *= patch_dim;
|
||||
|
||||
// Move `out` and `in` to the start of the patch.
|
||||
char* out = reinterpret_cast<char*>(patch);
|
||||
|
|
@ -233,9 +232,9 @@ void Image::GetPatch(size_t patch_num, float* patch) const {
|
|||
in += (((patch_y * width_) + patch_x) * kBytesPerPixel);
|
||||
|
||||
// Copy the patch one row at a time.
|
||||
for (size_t y = 0; y < kPatchSize; ++y) {
|
||||
std::memcpy(out, in, kBytesPerRow);
|
||||
out += kBytesPerRow;
|
||||
for (size_t y = 0; y < patch_dim; ++y) {
|
||||
std::memcpy(out, in, bytes_per_row);
|
||||
out += bytes_per_row;
|
||||
in += in_bytes_to_next_row;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "hwy/aligned_allocator.h" // Span
|
||||
#include "hwy/base.h" // Divisor
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
@ -44,11 +45,12 @@ class Image {
|
|||
bool WriteBinary(const std::string& filename) const;
|
||||
// Stores the patch for the given patch number in `patch`.
|
||||
// Patches are numbered in usual raster-order. E.g. for an image of size
|
||||
// 224 x 224, there are 16 x 16 = 256 patches.
|
||||
// `patch` should have space for at least 14 * 14 * 3 = 588 floats.
|
||||
// 224 x 224 and patch_dim = 14, there are 16 x 16 = 256 patches.
|
||||
// `patch` should have space for at least patch_dim * patch_dim * 3.
|
||||
// Requires that Normalize() has been called and that the image width and
|
||||
// height are multiples of 14.
|
||||
void GetPatch(size_t patch_num, float* patch) const;
|
||||
// height are multiples of patch_dim.
|
||||
void GetPatch(size_t patch_num, const hwy::Divisor& div_patch_dim,
|
||||
float* patch) const;
|
||||
|
||||
float *data() { return data_.data(); }
|
||||
const float *data() const { return data_.data(); }
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
|
@ -61,11 +62,12 @@ TEST(ImageTest, LoadResize224GetPatch) {
|
|||
EXPECT_EQ(image.data()[image.size() - 1], Normalize(122));
|
||||
// Extract two patches.
|
||||
float patch[588];
|
||||
image.GetPatch(0, patch);
|
||||
const hwy::Divisor div_patch_dim(14);
|
||||
image.GetPatch(0, div_patch_dim, patch);
|
||||
EXPECT_EQ(patch[0], Normalize(160));
|
||||
EXPECT_EQ(patch[1], Normalize(184));
|
||||
EXPECT_EQ(patch[2], Normalize(188));
|
||||
image.GetPatch(18, patch);
|
||||
image.GetPatch(18, div_patch_dim, patch);
|
||||
// Check the first row of the patch.
|
||||
for (size_t i = 0; i < 14 * 3; ++i) {
|
||||
EXPECT_EQ(patch[i], image.data()[(14 * 224 + 2 * 14) * 3 + i]);
|
||||
|
|
@ -108,14 +110,15 @@ TEST(ImageTest, Non224) {
|
|||
// Extract two patches.
|
||||
const size_t kPatchValues = 14 * 14 * 3; // = 588
|
||||
float patch[kPatchValues];
|
||||
const hwy::Divisor div_patch_dim(14);
|
||||
// Patch 0 is just the "start" of the image.
|
||||
image.GetPatch(0, patch);
|
||||
image.GetPatch(0, div_patch_dim, patch);
|
||||
EXPECT_NEAR(patch[0], Normalize(0.0f, max_value), 1e-6);
|
||||
EXPECT_NEAR(patch[1], Normalize(1.0f, max_value), 1e-6);
|
||||
EXPECT_NEAR(patch[2], Normalize(2.0f, max_value), 1e-6);
|
||||
// The "image" has 4x3 patches, so patch 6 has coordinates (1, 2) and its
|
||||
// pixel coordinates are offset by (14, 28).
|
||||
image.GetPatch(6, patch);
|
||||
image.GetPatch(6, div_patch_dim, patch);
|
||||
for (size_t n = 0; n < kPatchValues; ++n) {
|
||||
size_t k = n % 3;
|
||||
size_t j = ((n - k) / 3) % 14;
|
||||
|
|
|
|||
|
|
@ -21,10 +21,9 @@
|
|||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "io/io.h"
|
||||
#include "paligemma/paligemma_helper.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
#include "paligemma/paligemma_helper.h"
|
||||
|
||||
// This test can be run manually with the downloaded PaliGemma weights.
|
||||
// It should pass for `paligemma-3b-mix-224` and `paligemma2-3b-pt-448`.
|
||||
|
|
@ -72,9 +71,12 @@ TEST_F(PaliGemmaTest, QueryObjects) {
|
|||
|
||||
int main(int argc, char** argv) {
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
gcpp::InternalInit();
|
||||
|
||||
gcpp::GemmaEnv env(argc, argv);
|
||||
gcpp::ConsumedArgs consumed(argc, argv);
|
||||
gcpp::GemmaArgs args(argc, argv, consumed);
|
||||
consumed.AbortIfUnconsumed();
|
||||
|
||||
gcpp::GemmaEnv env(args);
|
||||
gcpp::s_env = &env;
|
||||
|
||||
return RUN_ALL_TESTS();
|
||||
|
|
|
|||
|
|
@ -173,6 +173,8 @@ PYBIND11_MODULE(configs, py_module) {
|
|||
.def_readwrite("secondary_eos_id", &ModelConfig::secondary_eos_id)
|
||||
.def_readwrite("scale_base_names", &ModelConfig::scale_base_names)
|
||||
.def_readwrite("internal", &ModelConfig::internal)
|
||||
.def_readwrite("use_global_timescale",
|
||||
&ModelConfig::use_global_timescale)
|
||||
|
||||
.def("add_layer_config", &ModelConfig::AddLayerConfig,
|
||||
arg("layer_config"))
|
||||
|
|
|
|||
|
|
@ -45,10 +45,7 @@ static void RemoveTrailingZeros(std::vector<int> &vec) {
|
|||
// Wrapper around GemmaEnv to expose to Python.
|
||||
class GemmaModel {
|
||||
public:
|
||||
GemmaModel(const gcpp::LoaderArgs& loader,
|
||||
const gcpp::ThreadingArgs& threading,
|
||||
const gcpp::InferenceArgs& inference)
|
||||
: env_(loader, threading, inference), last_prob_(0.0f) {}
|
||||
GemmaModel(const gcpp::GemmaArgs& args) : env_(args), last_prob_(0.0f) {}
|
||||
|
||||
// Generates a single example, given a prompt and a callback to stream the
|
||||
// generated tokens.
|
||||
|
|
@ -254,13 +251,15 @@ PYBIND11_MODULE(gemma, mod) {
|
|||
py::class_<GemmaModel>(mod, "GemmaModel")
|
||||
.def(py::init([](const std::string& tokenizer, const std::string& weights,
|
||||
size_t max_threads) {
|
||||
const gcpp::LoaderArgs loader(tokenizer, weights);
|
||||
gcpp::ThreadingArgs threading;
|
||||
threading.max_lps = max_threads;
|
||||
|
||||
gcpp::InferenceArgs inference;
|
||||
inference.max_generated_tokens = 512;
|
||||
auto gemma =
|
||||
std::make_unique<GemmaModel>(loader, threading, inference);
|
||||
|
||||
const gcpp::GemmaArgs args(gcpp::LoaderArgs(tokenizer, weights),
|
||||
threading, inference);
|
||||
auto gemma = std::make_unique<GemmaModel>(args);
|
||||
if (!gemma->ModelIsLoaded()) {
|
||||
throw std::invalid_argument("Could not load model.");
|
||||
}
|
||||
|
|
|
|||
78
util/args.h
78
util/args.h
|
|
@ -22,6 +22,7 @@
|
|||
|
||||
#include <algorithm> // std::transform
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "io/io.h" // Path
|
||||
#include "util/basics.h" // Tristate
|
||||
|
|
@ -29,6 +30,56 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
// For checking which args were not matched/consumed. Passed to each `*Args`
|
||||
// ctor that parses argc/argv to ensure that their args are tracked, without
|
||||
// requiring global state.
|
||||
class ConsumedArgs {
|
||||
public:
|
||||
ConsumedArgs(int argc, char** argv) : argv_(argv), consumed_(argc) {
|
||||
// We assume argc >= 1, because argv[0] is the binary name. That allows us
|
||||
// to signal "called AbortIfUnconsumed" with an empty vector.
|
||||
HWY_ASSERT(!consumed_.empty());
|
||||
}
|
||||
|
||||
~ConsumedArgs() {
|
||||
if (HWY_UNLIKELY(!consumed_.empty())) {
|
||||
HWY_ABORT("AbortIfUnconsumed was not called.");
|
||||
}
|
||||
}
|
||||
|
||||
void NotifyConsumed(size_t idx) {
|
||||
HWY_ASSERT(idx < consumed_.size());
|
||||
HWY_ASSERT(consumed_[idx] == 0);
|
||||
consumed_[idx] = 1;
|
||||
}
|
||||
|
||||
// Returns index of first unconsumed arg, or 0 if none. Also disarms the
|
||||
// warning in the dtor checking whether this/`AbortIfUnconsumed` were called.
|
||||
size_t FirstUnconsumed() {
|
||||
// Ignore argv[0], which is the binary name.
|
||||
for (size_t i = 1; i < consumed_.size(); ++i) {
|
||||
if (HWY_UNLIKELY(consumed_[i] == 0)) {
|
||||
consumed_.clear();
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
||||
consumed_.clear();
|
||||
return 0;
|
||||
}
|
||||
|
||||
void AbortIfUnconsumed() {
|
||||
const size_t i = FirstUnconsumed();
|
||||
if (HWY_UNLIKELY(i != 0)) {
|
||||
HWY_ABORT("Unrecognized arg %zu: %s\n", i, argv_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
char** argv_;
|
||||
std::vector<uint8_t> consumed_;
|
||||
};
|
||||
|
||||
// Args is a class that provides a ForEach member function which visits each of
|
||||
// its member variables. ArgsBase provides functions called by Args to
|
||||
// initialize values to their defaults (passed as an argument to the visitor),
|
||||
|
|
@ -93,12 +144,14 @@ class ArgsBase {
|
|||
// consider adding a hash-map to speed this up.
|
||||
class ParseVisitor {
|
||||
public:
|
||||
ParseVisitor(int argc, char* argv[]) : argc_(argc), argv_(argv) {}
|
||||
ParseVisitor(int argc, char* argv[], ConsumedArgs& consumed)
|
||||
: argc_(argc), argv_(argv), consumed_(consumed) {}
|
||||
|
||||
template <typename T>
|
||||
void operator()(T& t, const char* name, const T& /*init*/,
|
||||
const char* /*help*/, int /*print_verbosity*/ = 0) const {
|
||||
const std::string prefixed = std::string("--") + name;
|
||||
const std::string prefixed_eq = prefixed + "=";
|
||||
for (int i = 1; i < argc_; ++i) {
|
||||
if (std::string(argv_[i]) == prefixed) {
|
||||
if (i + 1 >= argc_) {
|
||||
|
|
@ -107,6 +160,16 @@ class ArgsBase {
|
|||
if (!SetValue(argv_[i + 1], t)) {
|
||||
HWY_ABORT("Invalid value for %s, got %s\n", name, argv_[i + 1]);
|
||||
}
|
||||
consumed_.NotifyConsumed(i);
|
||||
consumed_.NotifyConsumed(i + 1);
|
||||
return;
|
||||
}
|
||||
if (std::string(argv_[i]).find(prefixed_eq) == 0) {
|
||||
const char* value = argv_[i] + prefixed_eq.length();
|
||||
if (!SetValue(value, t)) {
|
||||
HWY_ABORT("Invalid value for %s, got %s\n", name, value);
|
||||
}
|
||||
consumed_.NotifyConsumed(i);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
@ -173,8 +236,9 @@ class ArgsBase {
|
|||
}
|
||||
}
|
||||
|
||||
int argc_;
|
||||
char** argv_;
|
||||
const int argc_;
|
||||
char** const argv_;
|
||||
ConsumedArgs& consumed_;
|
||||
}; // ParseVisitor
|
||||
|
||||
template <class Visitor>
|
||||
|
|
@ -203,15 +267,15 @@ class ArgsBase {
|
|||
ForEach(visitor);
|
||||
}
|
||||
|
||||
void Parse(int argc, char* argv[]) {
|
||||
ParseVisitor visitor(argc, argv);
|
||||
void Parse(int argc, char* argv[], ConsumedArgs& consumed) {
|
||||
ParseVisitor visitor(argc, argv, consumed);
|
||||
ForEach(visitor);
|
||||
}
|
||||
|
||||
// For convenience, enables single-line constructor.
|
||||
void InitAndParse(int argc, char* argv[]) {
|
||||
void InitAndParse(int argc, char* argv[], ConsumedArgs& consumed) {
|
||||
Init();
|
||||
Parse(argc, argv);
|
||||
Parse(argc, argv, consumed);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@@ -33,6 +33,9 @@ namespace gcpp {
// For hwy::BitSet4096. Note that KVs are extremely large for such batches.
HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096;

// Multiplier so a u64 occupies an entire cache line; avoids false sharing.
HWY_INLINE_VAR constexpr size_t kU64PerLine = HWY_ALIGNMENT / sizeof(uint64_t);

enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 };

static inline const char* ToString(Tristate t) {
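To illustrate why `kU64PerLine` exists (a sketch, not from the commit): per-thread counters are strided one cache line apart, so concurrent increments do not bounce a shared line between cores.

  // With HWY_ALIGNMENT = 64, kU64PerLine = 8: thread t only touches
  // counters[t * 8], so no two threads write to the same cache line.
  std::vector<uint64_t> counters(num_workers * gcpp::kU64PerLine);
  pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
    counters[thread * gcpp::kU64PerLine]++;  // no false sharing
  });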
33 util/mat.h

@@ -181,7 +181,15 @@ class MatPtr : public IFields {
  Extents2D Extents() const { return Extents2D(Rows(), cols_); }
  bool IsEmpty() const { return Rows() == 0 || cols_ == 0; }
  bool SameShape(const MatPtr& other) const {
    return Rows() == other.Rows() && cols_ == other.cols_;
    return Rows() == other.Rows() && Cols() == other.Cols();
  }
  void DebugCheckSameShape(const MatPtr& other) const {
    if constexpr (HWY_IS_DEBUG_BUILD) {
      if (!SameShape(other)) {
        HWY_ABORT("%s: shape mismatch %zu x %zu vs %zu x %zu\n", name_.c_str(),
                  Rows(), Cols(), other.Rows(), other.Cols());
      }
    }
  }
  // Future calls to `Rows()` during this class' lifetime (not serialized)
  // will return this value. Used to set the actual number of rows for
@@ -284,6 +292,9 @@ class MatPtrT : public MatPtr {
 public:
  using T = MatT;

  // Default constructor for use with uninitialized views.
  MatPtrT() = default;

  // Called by `MatStorageT`.
  MatPtrT(const char* name, Extents2D extents)
      : MatPtr(name, TypeEnum<MatT>(), extents) {}
@@ -296,7 +307,10 @@ class MatPtrT : public MatPtr {
    if (GetType() == Type::kUnknown) {
      SetType(TypeEnum<MatT>());
    } else {
      HWY_ASSERT(other.GetType() == TypeEnum<MatT>());
      if (HWY_UNLIKELY(other.GetType() != TypeEnum<MatT>())) {
        HWY_ABORT("Type mismatch: MatT %s, constructing from %s",
                  TypeName<MatT>(), TypeName(other.GetType()));
      }
    }
  }
  MatPtrT& operator=(const MatPtr& other) {
@@ -440,6 +454,21 @@ decltype(auto) CallUpcastedActivation(const MatPtr* base, const Func& func,
  }
}

// Like CallUpcasted, but only for kv_cache types: kBF16 and kF32.
template <class Func, typename... Args>
decltype(auto) CallUpcastedKV(const MatPtr* base, const Func& func,
                              Args&&... args) {
  if (base->GetType() == Type::kF32) {
    const MatPtrT<float> mat(*base);
    return func(&mat, std::forward<Args>(args)...);
  } else if (base->GetType() == Type::kBF16) {
    const MatPtrT<BF16> mat(*base);
    return func(&mat, std::forward<Args>(args)...);
  } else {
    HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
  }
}

void CopyMat(const MatPtr& from, MatPtr& to);
void ZeroInit(MatPtr& mat);
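A sketch of how `CallUpcastedKV` might be called (illustrative only; `kv_mat` is a hypothetical type-erased `MatPtr` holding a KV-cache tensor):

  // The generic lambda is instantiated for MatPtrT<float> and MatPtrT<BF16>;
  // the runtime Type tag selects which instantiation actually runs.
  const float total = CallUpcastedKV(&kv_mat, [](const auto* mat) {
    float sum = 0.0f;
    for (size_t r = 0; r < mat->Rows(); ++r) {
      for (size_t c = 0; c < mat->Cols(); ++c) {
        sum += hwy::ConvertScalarTo<float>(mat->Row(r)[c]);
      }
    }
    return sum;
  });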
@@ -19,20 +19,51 @@
#include <stddef.h>
#include <stdint.h>

#include <algorithm>  // std::sort
#include <cmath>
#include <iostream>

#include "util/basics.h"  // RngStream
#include "util/mat.h"
#include "hwy/base.h"

// IWYU pragma: begin_exports
#include "hwy/nanobenchmark.h"
#include "hwy/stats.h"
#include "hwy/tests/test_util.h"  // RandomState
// IWYU pragma: end_exports

namespace gcpp {

// Excludes outliers; we might not have enough samples for a reliable mode.
HWY_INLINE double TrimmedMean(double* seconds, size_t num) {
  std::sort(seconds, seconds + num);
  double sum = 0;
  int count = 0;
  for (size_t i = num / 4; i < num / 2; ++i) {
    sum += seconds[i];
    count += 1;
  }
  HWY_DASSERT(num != 0);
  return sum / count;
}

// Returns normalized value in [-1, 1).
HWY_INLINE float RandomFloat(RngStream& rng) {
  const uint32_t exp = hwy::BitCastScalar<uint32_t>(1.0f);
  const uint32_t mantissa_mask = hwy::MantissaMask<float>();
  const uint32_t representation = exp | (rng() & mantissa_mask);
  const float f12 = hwy::BitCastScalar<float>(representation);
  HWY_DASSERT(1.0f <= f12 && f12 < 2.0f);  // exponent is 2^0, only mantissa
  const float f = (2.0f * (f12 - 1.0f)) - 1.0f;
  HWY_DASSERT(-1.0f <= f && f < 1.0f);
  return f;
}

// Returns random Gaussian (mean=0, stddev=1/3 similar to expected weights)
// using the central limit theorem. Avoid std::normal_distribution for
// consistent cross-platform output.
// TODO: use RngStream instead of RandomState.
HWY_INLINE double RandomGaussian(hwy::RandomState& rng) {
  uint64_t sum = 0;
  constexpr int kReps = 40;
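The bit trick in `RandomFloat`, spelled out: with the exponent fixed to that of 1.0f, the random mantissa bits select a uniform f12 in [1, 2). The affine map f = 2 * (f12 - 1) - 1 then rescales [1, 2) to [-1, 1). For example, an all-zero mantissa gives f12 = 1.0 and f = -1.0; an all-ones mantissa gives f12 = 2 - 2^-23 and f = 1 - 2^-22, just below 1.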
@@ -71,6 +102,25 @@ HWY_INLINE void VerifyGaussian(hwy::Stats& stats) {
  HWY_ASSERT(IsNear(3.0, stats.Kurtosis(), 0.3));
}

template <typename T>
void FillMatPtrT(MatPtrT<T>& mat) {
  for (size_t i = 0; i < mat.Rows(); ++i) {
    for (size_t j = 0; j < mat.Cols(); ++j) {
      mat.Row(i)[j] = hwy::Unpredictable1() * 0.01f * (i + j + 1);
    }
  }
}

template <typename T>
void PrintMatPtr(MatPtrT<T> mat) {
  for (size_t i = 0; i < mat.Rows(); ++i) {
    for (size_t j = 0; j < mat.Cols(); ++j) {
      std::cerr << mat.Row(i)[j] << ", ";
    }
    std::cerr << std::endl;
  }
}

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_UTIL_TEST_UTIL_H_
@@ -187,7 +187,9 @@ class NestedPools {
// functions below.
class IndexRangePartition {
 public:
  IndexRangePartition() = default;  // for MMPartitions
  explicit IndexRangePartition(size_t single_task)
      : range_(0, single_task), task_size_(single_task), num_tasks_(1) {}

  IndexRangePartition(const IndexRange& range, const size_t task_size)
      : range_(range), task_size_(static_cast<uint32_t>(task_size)) {
    const uint32_t num = static_cast<uint32_t>(range.Num());
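A worked example of this partitioning, taken from the unit test removed further below in threading_test.cc: StaticPartition(IndexRange(0, 10), 2, 4) splits ten indices into two tasks whose size is rounded to a multiple of 4, yielding the ranges [0, 8) and [8, 10).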
@@ -262,43 +264,6 @@ static inline IndexRangePartition StaticPartition(const IndexRange& range,
  return IndexRangePartition(range, size);
}

// Parallel-for over a single range. This takes care of translating the task
// index to a range.
template <class Func>
void ParallelizeOneRange(const IndexRangePartition& get1, hwy::ThreadPool& pool,
                         hwy::pool::Caller caller, const Func& func) {
  const size_t num_tasks = get1.NumTasks();
  pool.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) {
    const IndexRange range1 = get1.Range(task);
    func(range1, thread);
  });
}

// Parallel-for over the Cartesian product of the two sets of ranges. This
// combines their indices into a single 'task' so they can be executed by one
// `pool.Run`, which increases the amount of work available to workers and
// reduces fork-join overhead vs. nested parallel-for loops. Calls `func` with
// the two ranges and the thread index within `pool`.
template <class Func>
void ParallelizeTwoRanges(const IndexRangePartition& get1,
                          const IndexRangePartition& get2,
                          hwy::ThreadPool& pool, hwy::pool::Caller caller,
                          const Func& func) {
  const hwy::Divisor div1(static_cast<uint32_t>(get1.NumTasks()));

  const size_t num_tasks = get1.NumTasks() * get2.NumTasks();
  pool.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) {
    HWY_DASSERT(task < (uint64_t{1} << 32));
    const size_t idx2 = div1.Divide(static_cast<uint32_t>(task));
    const size_t idx1 = div1.Remainder(static_cast<uint32_t>(task));
    HWY_DASSERT(idx1 < get1.NumTasks());
    HWY_DASSERT(idx2 < get2.NumTasks());
    const IndexRange range1 = get1.Range(idx1);
    const IndexRange range2 = get2.Range(idx2);
    func(range1, range2, thread);
  });
}

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
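The linearization that `ParallelizeTwoRanges` relies on: a single task index encodes the pair (idx1, idx2) as task = idx2 * N1 + idx1. A plain-integer sketch of the same decomposition:

  // Equivalent to Divisor::Divide/Remainder, where N1 = get1.NumTasks():
  const size_t idx2 = task / N1;  // selects the range from get2
  const size_t idx1 = task % N1;  // selects the range from get1
  // hwy::Divisor precomputes a reciprocal so the hot per-task lambda avoids
  // a hardware division.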
@@ -43,12 +43,9 @@ static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) {
  const size_t num_tasks[4] = {HWY_MAX(1, num_workers / 2), num_workers * 1,
                               num_workers * 5, num_workers * 20};

  // Count tasks executed to ensure workers aren't optimized out. One per
  // cache line to avoid false sharing.
  const size_t kSizePerLine = HWY_ALIGNMENT / sizeof(size_t);

  std::vector<size_t> counters(num_workers * kSizePerLine);
  size_t prev_total = 0;  // avoids having to reset counters.
  // Count tasks executed to ensure workers aren't optimized out.
  std::vector<uint64_t> counters(num_workers * kU64PerLine);
  uint64_t prev_total = 0;  // avoids having to reset counters.

  hwy::RandomState rng;
  for (size_t rep = 0; rep < 500; ++rep) {
@@ -63,13 +60,13 @@ static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) {
    pool.Run(begin, end, [&](uint64_t task, size_t thread) {
      HWY_ASSERT(begin <= task && task < end);
      HWY_ASSERT(thread < num_workers);
      counters[thread * kSizePerLine]++;
      counters[thread * kU64PerLine]++;
    });

    // Reduce count and ensure it matches the expected number of tasks.
    size_t total = 0;
    uint64_t total = 0;
    for (size_t i = 0; i < num_workers; ++i) {
      total += counters[i * kSizePerLine];
      total += counters[i * kU64PerLine];
    }
    const size_t expected = end - begin;
    HWY_ASSERT(total == prev_total + expected);
@@ -100,7 +97,8 @@ ThreadingContext::ThreadingContext(const ThreadingArgs& args)
          BoundedSlice(args.skip_lps, args.max_lps)),
      cache_info(topology),
      allocator(topology, cache_info, args.bind != Tristate::kFalse),
      pools(topology, allocator, args.max_threads, args.pin) {
      pools(topology, allocator, args.max_threads, args.pin),
      tensor_output(args.tensor_output) {
  PROFILER_ZONE("Startup.ThreadingContext autotune");
  TunePools(hwy::PoolWaitMode::kSpin, *this);
  // kBlock is the default, hence set/tune it last.
@@ -23,6 +23,7 @@
#include <stdint.h>

// IWYU pragma: begin_exports
#include "io/io.h"  // Path
#include "util/allocator.h"
#include "util/args.h"
#include "util/basics.h"  // Tristate
@@ -37,7 +38,9 @@ namespace gcpp {
// Optional arguments for `ThreadingContext` from the command line.
class ThreadingArgs : public ArgsBase<ThreadingArgs> {
 public:
  ThreadingArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
  ThreadingArgs(int argc, char* argv[], ConsumedArgs& consumed) {
    InitAndParse(argc, argv, consumed);
  }
  ThreadingArgs() { Init(); }

  // For BoundedTopology:
@@ -55,6 +58,8 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
  Tristate pin;   // pin threads?
  Tristate spin;  // use spin waits?

  Path tensor_output;  // empty, or directory for tensor output

  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    // These can be used to partition CPU packages/sockets and their
@@ -85,6 +90,9 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {

    visitor(bind, "bind", Tristate::kDefault,
            "Bind memory to sockets? -1 = auto, 0 = no, 1 = yes.", 2);

    visitor(tensor_output, "tensor_output", Path(),
            "Empty, or directory for tensor output.", 2);
  }
};
@@ -100,7 +108,7 @@ struct ThreadingContext {

  // Returns a worker index compatible with those from `ParallelFor`, assuming
  // the current thread is running on one thread per cluster, which happens
  // when `ParallelismStrategy` is `kAcrossClusters`.
  // when `Parallelism` is `kAcrossClusters`.
  size_t Worker(size_t cluster_idx) const {
    return cluster_idx * pools.MaxWorkersPerCluster();
  }
@@ -124,13 +132,15 @@ struct ThreadingContext {

  // Per-package/cluster/within cluster pools of threads, matching `topology`.
  NestedPools pools;

  Path tensor_output;  // used by `TensorStats::Notify`.
};

#define GCPP_ZONE(ctx, global_idx, zone_enum) \
  PROFILER_ZONE3(ctx.profiler, global_idx, ctx.profiler_zones.Get(zone_enum))

// Describes the strategy for distributing parallel work across cores.
enum class ParallelismStrategy : uint8_t {
enum class Parallelism : uint8_t {
  // Execute using a single-threaded loop on the calling thread. The `worker`
  // index passed to the user's `Func` is unique across clusters.
  kNone,
@@ -154,56 +164,110 @@ enum class ParallelismStrategy : uint8_t {
  kHierarchical,
};

// Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes
// over clusters of ONE package, then within each cluster.
// Helper functions used to implement `ParallelFor`, also reused in multiple
// places. User code should call `ParallelFor` instead, which accepts the more
// convenient `Callers` enum.
//
// These call `func(task, worker)` for each task in `[0, num_tasks)`.

// NOTE: the worker argument is actually the `cluster_idx`, so that `Func` can
// pass that to `ParallelForWithinCluster`.
template <class Func>
void ParallelForAcrossClusters(size_t num_tasks, ThreadingContext& ctx,
                               hwy::pool::Caller caller, const Func& func) {
  ctx.pools.AllClusters().Run(
      0, num_tasks, caller,
      [&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); });
}

template <class Func>
void ParallelForWithinCluster(size_t num_tasks, ThreadingContext& ctx,
                              size_t cluster_idx, hwy::pool::Caller caller,
                              const Func& func) {
  const size_t cluster_base = ctx.Worker(cluster_idx);
  ctx.pools.Cluster(cluster_idx)
      .Run(0, num_tasks, caller, [&](uint64_t task, size_t worker) {
        func(task, cluster_base + worker);
      });
}

// Calls `func(range, cluster_idx)`, for passing to `*WithinCluster`.
template <class Func>
void ParallelPartitionAcrossClusters(const IndexRange range,
                                     size_t task_multiple, size_t inner_tasks,
                                     ThreadingContext& ctx,
                                     hwy::pool::Caller caller,
                                     const Func& func) {
  HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
  const IndexRangePartition ranges = StaticPartition(
      range, ctx.pools.NumClusters() * inner_tasks, task_multiple);
  ParallelForAcrossClusters(ranges.NumTasks(), ctx, caller,
                            [&](uint64_t task, size_t cluster_idx) {
                              func(ranges.Range(task), cluster_idx);
                            });
}

// Calls `func(range, worker)`.
template <class Func>
void ParallelPartitionWithinCluster(const IndexRange range,
                                    size_t task_multiple, size_t inner_tasks,
                                    ThreadingContext& ctx, size_t cluster_idx,
                                    hwy::pool::Caller caller,
                                    const Func& func) {
  HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
  const size_t num_workers = ctx.pools.Cluster(cluster_idx).NumWorkers();
  const IndexRangePartition ranges =
      StaticPartition(range, num_workers * inner_tasks, task_multiple);
  ParallelForWithinCluster(
      ranges.NumTasks(), ctx, cluster_idx, caller,
      [&](uint64_t task, size_t worker) { func(ranges.Range(task), worker); });
}

// Parallelizes across clusters, then within each cluster.
template <class Func>
void HierarchicalParallelFor(size_t num_tasks, ThreadingContext& ctx,
                             Callers callers, const Func& func) {
  const hwy::pool::Caller caller = ctx.pool_callers.Get(callers);
  // If few tasks, run on a single cluster. Also avoids a bit of overhead if
  // there is only one cluster.
  hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
  const size_t num_clusters = all_clusters.NumWorkers();
  hwy::ThreadPool& cluster = ctx.pools.Cluster(0);
  if (num_clusters == 1 || num_tasks <= cluster.NumWorkers()) {
    return cluster.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) {
      func(task, thread);
    });

  // If at most one task per cluster worker, run on a single cluster to avoid
  // the expensive cross-cluster barrier.
  {
    const size_t cluster_idx = 0;
    const size_t cluster_workers = ctx.pools.Cluster(cluster_idx).NumWorkers();
    if (HWY_UNLIKELY(num_tasks <= cluster_workers)) {
      return ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller,
                                      func);
    }
  }

  // Assign each cluster a sub-range.
  const IndexRangePartition ranges =
      StaticPartition(IndexRange(0, num_tasks), num_clusters, 1);
  ParallelizeOneRange(ranges, all_clusters, caller,
                      [&](const IndexRange& range, const size_t cluster_idx) {
                        hwy::ThreadPool& cluster =
                            ctx.pools.Cluster(cluster_idx);
                        const size_t cluster_base =
                            cluster_idx * ctx.pools.MaxWorkersPerCluster();
                        cluster.Run(range.begin(), range.end(), caller,
                                    [&](uint64_t task, size_t thread) {
                                      func(task, cluster_base + thread);
  ParallelPartitionAcrossClusters(
      IndexRange(0, num_tasks), /*task_multiple=*/1, /*inner_tasks=*/1, ctx,
      caller, [&](const IndexRange& cluster_range, size_t cluster_idx) {
        ParallelForWithinCluster(cluster_range.Num(), ctx, cluster_idx, caller,
                                 [&](uint64_t i, size_t worker) {
                                   func(cluster_range.begin() + i, worker);
                                 });
      });
}
// Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the
// number/type of workers determined by `parallelism`. `cluster_idx` is for
// `parallelism == kWithinCluster`, and should be 0 if unknown.
// number/type of workers determined by `parallelism`. NOTE: worker is actually
// `cluster_idx` for `kAcrossClusters`. The `cluster_idx` argument is for
// `parallelism == {kWithinCluster, kNone}`, and should be 0 if unknown.
template <class Func>
void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
void ParallelFor(Parallelism parallelism, size_t num_tasks,
                 ThreadingContext& ctx, size_t cluster_idx, Callers callers,
                 const Func& func) {
  HWY_DASSERT(cluster_idx < ctx.topology.NumClusters());
  if (cluster_idx != 0) {
    // If already running across clusters, only use within-cluster modes.
    HWY_DASSERT(parallelism == ParallelismStrategy::kNone ||
                parallelism == ParallelismStrategy::kWithinCluster);
    HWY_DASSERT(parallelism == Parallelism::kNone ||
                parallelism == Parallelism::kWithinCluster);
  }
  const hwy::pool::Caller caller = ctx.pool_callers.Get(callers);

  switch (parallelism) {
    case ParallelismStrategy::kNone: {
    case Parallelism::kNone: {
      const size_t worker = ctx.Worker(cluster_idx);
      for (size_t task = 0; task < num_tasks; ++task) {
        func(task, worker);
@@ -211,40 +275,28 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
      return;
    }

    case ParallelismStrategy::kAcrossClusters:
      return ctx.pools.AllClusters().Run(
          0, num_tasks, caller,
    case Parallelism::kAcrossClusters:
      return ParallelForAcrossClusters(
          num_tasks, ctx, caller,
          [&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); });

    case ParallelismStrategy::kWithinCluster: {
      // Ensure the worker argument is unique across clusters, because it is
      // used for TLS indexing for example in profiler.h.
      const size_t base = ctx.Worker(cluster_idx);
      return ctx.pools.Cluster(cluster_idx)
          .Run(0, num_tasks, caller, [&](uint64_t task, size_t worker) {
            func(task, base + worker);
          });
    }
    case Parallelism::kWithinCluster:
      return ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller,
                                      func);

    case ParallelismStrategy::kFlat: {
      // Check for single cluster; if not, we must compute `cluster_base` for
      // consistent and non-overlapping worker indices.
      hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
      const size_t num_clusters = all_clusters.NumWorkers();
      if (num_clusters == 1) {
        return ctx.pools.Cluster(cluster_idx)
            .Run(0, num_tasks, caller,
                 [&](uint64_t task, size_t worker) { func(task, worker); });
    case Parallelism::kFlat:
      // Choose a single pool: the only cluster, or across all clusters
      // (slower synchronization, but more memory bandwidth).
      if (HWY_UNLIKELY(ctx.pools.NumClusters() == 1)) {
        return ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller,
                                        func);
      }

      return all_clusters.Run(0, num_tasks, caller,
      return ParallelForAcrossClusters(num_tasks, ctx, caller,
                                       [&](uint64_t task, size_t cluster_idx) {
                                         const size_t worker = ctx.Worker(cluster_idx);
                                         func(task, worker);
                                         func(task, ctx.Worker(cluster_idx));
                                       });
    }

    case ParallelismStrategy::kHierarchical:
    case Parallelism::kHierarchical:
      return HierarchicalParallelFor(num_tasks, ctx, callers, func);
  }
}
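A usage sketch of the renamed entry point (illustrative; assumes an existing `ThreadingContext ctx`, and `ProcessRow`/`scratch` are hypothetical):

  // Hierarchical parallel loop over rows. `worker` is unique across all
  // clusters, so it can index per-worker scratch without collisions.
  ParallelFor(Parallelism::kHierarchical, num_rows, ctx,
              /*cluster_idx=*/0, Callers::kTest,
              [&](uint64_t row, size_t worker) {
                ProcessRow(row, scratch[worker]);
              });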
@@ -202,59 +202,7 @@ TEST(ThreadingTest, TestStaticPartition) {
  }
}

TEST(ThreadingTest, TestParallelizeOneRange) {
  const IndexRange range(0, 10);
  const IndexRangePartition partition = StaticPartition(range, 2, 4);
  hwy::ThreadPool null_pool(0);
  size_t calls = 0;
  ParallelizeOneRange(partition, null_pool, kCaller,
                      [&](const IndexRange& range, size_t) {
                        if (++calls == 1) {
                          HWY_ASSERT(range.begin() == 0 && range.end() == 8);
                        } else {
                          HWY_ASSERT(range.begin() == 8 && range.end() == 10);
                        }
                      });
  HWY_ASSERT(calls == 2);
}

TEST(ThreadingTest, TestParallelizeTwoRanges) {
  const IndexRangePartition partition1 =
      StaticPartition(IndexRange(0, 10), 2, 4);
  const IndexRangePartition partition2 =
      MaxSizePartition(IndexRange(128, 256), 32, 32);
  HWY_ASSERT(partition2.NumTasks() == 4);
  hwy::ThreadPool null_pool(0);
  {
    size_t calls = 0;
    ParallelizeTwoRanges(
        partition1, partition2, null_pool, kCaller,
        [&](const IndexRange& range1, const IndexRange& range2, size_t) {
          ++calls;
          HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8);
          HWY_ASSERT(range2.begin() % 32 == 0);
          HWY_ASSERT(range2.Num() % 32 == 0);
        });
    HWY_ASSERT(calls == 2 * 4);
  }

  // Also swap order to test Remainder() logic.
  {
    size_t calls = 0;
    ParallelizeTwoRanges(
        partition2, partition1, null_pool, kCaller,
        [&](const IndexRange& range2, const IndexRange& range1, size_t) {
          ++calls;
          HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8);
          HWY_ASSERT(range2.begin() % 32 == 0);
          HWY_ASSERT(range2.Num() % 32 == 0);
        });
    HWY_ASSERT(calls == 2 * 4);
  }
}

static constexpr size_t kU64PerThread = HWY_ALIGNMENT / sizeof(size_t);
static uint64_t outputs[hwy::kMaxLogicalProcessors * kU64PerThread];
static uint64_t outputs[hwy::kMaxLogicalProcessors * kU64PerLine];

std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
  // Governs duration of test; avoid timeout in debug builds.
@@ -268,7 +216,7 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
  const double t0 = hwy::platform::Now();
  for (size_t reps = 0; reps < 1200; ++reps) {
    pool.Run(0, pool.NumWorkers(), kCaller, [&](uint64_t task, size_t thread) {
      outputs[thread * kU64PerThread] = base + thread;
      outputs[thread * kU64PerLine] = base + thread;
    });
    hwy::PreventElision(outputs[base]);
    if (pool.AutoTuneComplete()) break;
@@ -309,7 +257,7 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
    const uint64_t t0 = hwy::timer::Start();
    pool.Run(0, pool.NumWorkers(), kCaller,
             [&](uint64_t task, size_t thread) {
               outputs[thread * kU64PerThread] = base + thread;
               outputs[thread * kU64PerLine] = base + thread;
             });
    const uint64_t t1 = hwy::timer::Stop();
    times.push_back(t1 - t0);
@@ -319,7 +267,7 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
    const uint64_t t0 = hwy::timer::Start();
    pool.Run(0, pool.NumWorkers(), kCaller,
             [&](uint64_t task, size_t thread) {
               outputs[thread * kU64PerThread] = base + thread;
               outputs[thread * kU64PerLine] = base + thread;
             });
    const uint64_t t1 = hwy::timer::Start();
    times.push_back(t1 - t0);
@@ -366,10 +314,10 @@ TEST(ThreadingTest, BenchJoin) {

  // Verify outputs to ensure the measured code is not a no-op.
  for (size_t lp = 0; lp < pool.NumWorkers(); ++lp) {
    HWY_ASSERT(outputs[lp * kU64PerThread] >= 1);
    HWY_ASSERT(outputs[lp * kU64PerThread] <= 1 + pool.NumWorkers());
    for (size_t i = 1; i < kU64PerThread; ++i) {
      HWY_ASSERT(outputs[lp * kU64PerThread + i] == 0);
    HWY_ASSERT(outputs[lp * kU64PerLine] >= 1);
    HWY_ASSERT(outputs[lp * kU64PerLine] <= 1 + pool.NumWorkers());
    for (size_t i = 1; i < kU64PerLine; ++i) {
      HWY_ASSERT(outputs[lp * kU64PerLine + i] == 0);
    }
  }
};
130 util/topology.cc
@@ -21,12 +21,14 @@
#include <vector>

#include "hwy/base.h"
#include "hwy/bit_set.h"

namespace gcpp {

// Returns set of LPs available for use.
static LPS EnabledLPs(const BoundedSlice& lp_slice) {
  LPS enabled_lps;
  const size_t num_lps = hwy::TotalLogicalProcessors();

  // Thread-safe caching during the first call because subsequent pinning
  // overwrites the main thread's affinity.
@@ -35,6 +37,7 @@ static LPS EnabledLPs(const BoundedSlice& lp_slice) {
    if (!GetThreadAffinity(affinity)) affinity = LPS();
    return affinity;
  }();

  if (HWY_LIKELY(affinity.Any())) {
    // To honor taskset/numactl *and* the user's `lp_slice`, we interpret
    // the latter as a slice of the 1-bits of `enabled_lps`. Note that this
@@ -48,18 +51,32 @@ static LPS EnabledLPs(const BoundedSlice& lp_slice) {
      }
      ++enabled_idx;
    });
  } else {
    const size_t num_lps = hwy::TotalLogicalProcessors();
  }

  if (HWY_UNLIKELY(!enabled_lps.Any())) {
    // First warn: either about unknown affinity, or no overlap with `lp_slice`.
    if (!affinity.Any()) {
      // Do not warn on Apple, where affinity is not supported.
      if (!HWY_OS_APPLE) {
        HWY_WARN("unknown OS affinity, max %zu LPs and slice %zu.", num_lps,
                 lp_slice.Num(num_lps));
      }
    } else {
      HWY_WARN("LP slice [%zu, %zu) of initial affinity %zu is empty.",
               lp_slice.Begin(), lp_slice.End(num_lps), affinity.Count());
    }

    // Set `enabled_lps` based only on `lp_slice` and total logical processors.
    for (size_t lp = 0; lp < num_lps; ++lp) {
      if (lp_slice.Contains(num_lps, lp)) {
        enabled_lps.Set(lp);
      }
    }

    if (!enabled_lps.Any()) {
      HWY_WARN("no enabled LPs of total %zu, slice [%zu, %zu).", num_lps,
               lp_slice.Begin(), lp_slice.End(affinity.Count()));
    }
  }

  // Without threading support, only keep the first enabled LP; it might still
@@ -72,6 +89,7 @@ static LPS EnabledLPs(const BoundedSlice& lp_slice) {
    HWY_WARN("Warning, threads not supported, using only the main thread.");
  }

  HWY_ASSERT(enabled_lps.Any());
  return enabled_lps;
}
@@ -156,12 +174,13 @@ constexpr size_t kMaxLPsPerCluster = 6;

#if !GEMMA_DISABLE_TOPOLOGY

static size_t CoresFromLPs(const LPS& lps, const hwy::Topology& topology) {
  LPS cores;
  lps.Foreach([&](size_t lp) {
    if (topology.lps[lp].smt == 0) cores.Set(lp);
  });
  return cores.Count();
// Returns the number of distinct SMT indices (hyperthreads per core).
static size_t NumSMT(const hwy::Topology& topology) {
  hwy::BitSet64 smt;
  for (const hwy::Topology::LP& lp : topology.lps) {
    smt.Set(lp.smt);
  }
  return smt.Count();
}

// tcluster is a modifiable copy of the first cluster in the package.
@@ -187,34 +206,66 @@ void BoundedTopology::SplitLargeCluster(const LPS& enabled_lps,
  }
}

// Main part of ctor, called when topology is known.
bool BoundedTopology::InitFromTopology(const LPS& enabled_lps) {
  const size_t tpkg_idx = package_slice_.Begin();
  HWY_ASSERT(tpkg_idx < topology_.packages.size());
  const hwy::Topology::Package& tpackage = topology_.packages[tpkg_idx];
  const std::vector<hwy::Topology::Cluster>& tclusters = tpackage.clusters;
using TClusters = std::vector<hwy::Topology::Cluster>;

// Returns false if no cluster in `tclusters` has any enabled LPs.
static bool AnyEnabledLPs(const TClusters& tclusters, const LPS& enabled_lps) {
  if (HWY_UNLIKELY(tclusters.empty())) {
    HWY_WARN("Topology: no clusters found in package %zu.", tpkg_idx);
    HWY_WARN("Topology: no clusters found.");
    return false;
  }

  size_t max_tcluster_cores = 0;
  size_t max_tcluster_lps = 0;
  for (const hwy::Topology::Cluster& tcluster : tclusters) {
    const size_t cores = CoresFromLPs(tcluster.lps, topology_);
    const size_t lps = tcluster.lps.Count();
    max_tcluster_cores = HWY_MAX(max_tcluster_cores, cores);
    max_tcluster_lps = HWY_MAX(max_tcluster_lps, lps);
    bool any_lp_enabled = false;
    tcluster.lps.Foreach(
        [&](size_t lp) { any_lp_enabled |= (enabled_lps.Get(lp)); });
    if (any_lp_enabled) return true;
  }
  HWY_ASSERT(max_tcluster_cores != 0);
  HWY_ASSERT(max_tcluster_lps >= max_tcluster_cores);

  // No warning: this can happen if OS affinity limits us to the second package.
  return false;
}

// Returns nullptr on failure. Also attempts `1 - tpkg_idx`, which is suitable
// for the common case of up to two packages.
static const TClusters* GetPackageClusters(const hwy::Topology& topology,
                                           size_t tpkg_idx,
                                           const LPS& enabled_lps) {
  const size_t num_packages = topology.packages.size();
  HWY_ASSERT(tpkg_idx < num_packages);
  {
    const TClusters& tclusters = topology.packages[tpkg_idx].clusters;
    if (AnyEnabledLPs(tclusters, enabled_lps)) return &tclusters;
  }

  // Retry with the other package, if any.
  tpkg_idx ^= 1;
  if (tpkg_idx == num_packages) return nullptr;
  {
    const TClusters& tclusters = topology.packages[tpkg_idx].clusters;
    if (AnyEnabledLPs(tclusters, enabled_lps)) return &tclusters;
  }

  HWY_WARN(
      "Ignoring topology (%zu tpackages) because no clusters overlap with the "
      "OS affinity (%zu enabled LPs): ",
      num_packages, enabled_lps.Count());
  enabled_lps.Foreach([](size_t lp) { fprintf(stderr, "%zu, ", lp); });
  return nullptr;
}

// Main part of ctor, called when topology is known.
bool BoundedTopology::InitFromTopology(const LPS& enabled_lps) {
  const TClusters* maybe_tclusters =
      GetPackageClusters(topology_, package_slice_.Begin(), enabled_lps);
  if (!maybe_tclusters) return false;
  const TClusters& tclusters = *maybe_tclusters;

  // Populate `clusters` with the subset of clusters in `cluster_slice` that
  // have any enabled LPs.
  clusters_.reserve(cluster_slice_.Num(tclusters.size()));
  cluster_slice_.Foreach("cluster", tclusters.size(), [&](size_t cluster_idx) {
    const hwy::Topology::Cluster& tcluster = tpackage.clusters[cluster_idx];
    Cluster cluster(enabled_lps, topology_.lps, tcluster);
    Cluster cluster(enabled_lps, topology_.lps, tclusters[cluster_idx]);

    // Skip if empty, i.e. too few `enabled_lps`.
    if (HWY_LIKELY(cluster.NumWorkers() != 0)) {
@@ -223,14 +274,10 @@ bool BoundedTopology::InitFromTopology(const LPS& enabled_lps) {
      nodes_.Set(cluster.Node());
    }
  });
  if (HWY_UNLIKELY(clusters_.empty())) {
    HWY_WARN("Too restrictive cluster_slice or enabled_lps, no clusters left.");
    return false;
  }

  if (kSplitLargeClusters && clusters_.size() == 1 &&
      enabled_lps.Count() >= 16) {
    SplitLargeCluster(enabled_lps, tpackage.clusters[0]);
    SplitLargeCluster(enabled_lps, tclusters[0]);
  }

  // Sort by descending 'size' so that users who only use one get the largest.
@@ -239,20 +286,23 @@ bool BoundedTopology::InitFromTopology(const LPS& enabled_lps) {
    return a.NumWorkers() > b.NumWorkers();
  });

  // Largest number of enabled workers in any cluster, for `topology_string_`.
  // This may be less than `max_tcluster_cores` if `enabled_lps` excludes some.
  size_t max_cluster_workers = 0;
  for (const Cluster& c : clusters_) {
    max_cluster_workers = HWY_MAX(max_cluster_workers, c.NumWorkers());
  // Happens if all LPs are HTs (we checked that at least some LPs are enabled).
  if (HWY_UNLIKELY(clusters_.empty())) {
    HWY_WARN(
        "Ignoring topology - no usable clusters. cluster_slice [%zu, %zu), "
        "%zu tclusters, %zu tLPs, %zu enabled LPs: ",
        cluster_slice_.Begin(), cluster_slice_.End(tclusters.size()),
        tclusters.size(), topology_.lps.size(), enabled_lps.Count());
    enabled_lps.Foreach([](size_t lp) { fprintf(stderr, "%zu, ", lp); });
    return false;
  }
  HWY_ASSERT(max_cluster_workers <= max_tcluster_cores);
  // Do not warn about large clusters: GNR has 40.

  const size_t num_smt = NumSMT(topology_);
  snprintf(topology_string_, sizeof(topology_string_),
           "%zuS %zuX %zuC %zuH, using %zuX %zuC (nodes=%zu)",
           topology_.packages.size(), tclusters.size(), max_tcluster_cores,
           max_tcluster_lps / max_tcluster_cores, NumClusters(),
           max_cluster_workers, nodes_.Count());
           topology_.packages.size(), tclusters.size(),
           tclusters[0].lps.Count() / num_smt, num_smt, NumClusters(),
           clusters_[0].NumWorkers(), nodes_.Count());
  return true;
}
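Reading `topology_string_` (an illustrative decoding of the format above): "1S 4X 32C 2H, using 4X 30C (nodes=1)" would mean 1 package/socket (S), 4 clusters (X) of 32 cores (C) with 2 SMT threads per core (H), of which 4 clusters are in use and the largest has 30 enabled workers, on 1 NUMA node.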
@@ -93,7 +93,7 @@ class BoundedTopology {

  class Cluster {
   public:
    Cluster(const LPS& lps);
    explicit Cluster(const LPS& lps);
    Cluster(const LPS& enabled_lps,
            const std::vector<hwy::Topology::LP>& all_lps,
            const hwy::Topology::Cluster& tcluster);
@@ -51,6 +51,8 @@ const char* ZoneName(Zones zone) {
      return "Gen.SampleTop1";
    case Zones::kGenSampleTopK:
      return "Gen.SampleTopK";
    case Zones::kGenStats:
      return "Gen.Stats";
    case Zones::kMMDecompressA:
      return "MM.DecompressA";
    case Zones::kMMDispatch:
@@ -163,6 +165,8 @@ const char* CallerName(Callers caller) {
      return "ReadBatches";
    case Callers::kSampleAndStream:
      return "SampleAndStream";
    case Callers::kTensorStats:
      return "TensorStats";
    case Callers::kTest:  // only for unit tests.
      return "Test-only!";
    case Callers::kTunePool:
@@ -31,6 +31,7 @@ enum class Zones {  // Keep sorted
  kGenFFW,
  kGenSampleTop1,
  kGenSampleTopK,
  kGenStats,
  kMMDecompressA,
  kMMDispatch,
  kMMMatMul,
@@ -96,6 +97,7 @@ enum class Callers {  // Keep sorted
  kReadAllToBF16,
  kReadBatches,
  kSampleAndStream,
  kTensorStats,
  kTest,  // only for unit tests.
  kTunePool,
  kVitDotSoftmax1,