mirror of https://github.com/google/gemma.cpp.git
Added a global (to gemma) zones list so that most PROFILER_ZONE3 call sites can avoid the synchronization required for the static const initialization of the zone handle.
Improved flash_attention to enable profiling using the new zones. PiperOrigin-RevId: 819235421
This commit is contained in:
parent
035273c184
commit
fb6fa793f4
16
BUILD.bazel
16
BUILD.bazel
|
|
@ -111,6 +111,7 @@ cc_library(
|
||||||
":basics",
|
":basics",
|
||||||
":threading",
|
":threading",
|
||||||
":topology",
|
":topology",
|
||||||
|
":zones",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
"@highway//:hwy_test_util",
|
"@highway//:hwy_test_util",
|
||||||
"@highway//:profiler",
|
"@highway//:profiler",
|
||||||
|
|
@ -118,6 +119,15 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "zones",
|
||||||
|
srcs = ["util/zones.cc"],
|
||||||
|
hdrs = ["util/zones.h"],
|
||||||
|
deps = [
|
||||||
|
"@highway//:profiler",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "flash_attention_test",
|
name = "flash_attention_test",
|
||||||
srcs = ["gemma/flash_attention_test.cc"],
|
srcs = ["gemma/flash_attention_test.cc"],
|
||||||
|
|
@ -263,6 +273,7 @@ cc_library(
|
||||||
":model_store",
|
":model_store",
|
||||||
":tensor_info",
|
":tensor_info",
|
||||||
":threading_context",
|
":threading_context",
|
||||||
|
":zones",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"//io:blob_store",
|
"//io:blob_store",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
|
|
@ -321,6 +332,7 @@ cc_library(
|
||||||
":matmul_env",
|
":matmul_env",
|
||||||
":threading",
|
":threading",
|
||||||
":threading_context",
|
":threading_context",
|
||||||
|
":zones",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@highway//:bit_set",
|
"@highway//:bit_set",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
|
|
@ -352,6 +364,7 @@ cc_library(
|
||||||
":matmul",
|
":matmul",
|
||||||
":matmul_env",
|
":matmul_env",
|
||||||
":threading_context",
|
":threading_context",
|
||||||
|
":zones",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"//compression:types",
|
"//compression:types",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
|
|
@ -376,6 +389,7 @@ cc_library(
|
||||||
":matmul_env", # MMOptions
|
":matmul_env", # MMOptions
|
||||||
":matmul_static",
|
":matmul_static",
|
||||||
":threading_context",
|
":threading_context",
|
||||||
|
":zones",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@highway//:algo",
|
"@highway//:algo",
|
||||||
"@highway//:bit_set",
|
"@highway//:bit_set",
|
||||||
|
|
@ -431,6 +445,7 @@ cc_test(
|
||||||
":ops",
|
":ops",
|
||||||
":test_util",
|
":test_util",
|
||||||
":threading_context",
|
":threading_context",
|
||||||
|
":zones",
|
||||||
"@googletest//:gtest_main", # buildcleaner: keep
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
"//compression:test_util",
|
"//compression:test_util",
|
||||||
"//compression:types",
|
"//compression:types",
|
||||||
|
|
@ -556,6 +571,7 @@ cc_library(
|
||||||
":threading",
|
":threading",
|
||||||
":threading_context",
|
":threading_context",
|
||||||
":weights",
|
":weights",
|
||||||
|
":zones",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"//compression:types",
|
"//compression:types",
|
||||||
"//io",
|
"//io",
|
||||||
|
|
|
||||||
|
|
@ -130,6 +130,8 @@ set(SOURCES
|
||||||
util/threading.h
|
util/threading.h
|
||||||
util/topology.cc
|
util/topology.cc
|
||||||
util/topology.h
|
util/topology.h
|
||||||
|
util/zones.cc
|
||||||
|
util/zones.h
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add C API sources only when building DLL
|
# Add C API sources only when building DLL
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||||
|
#include "util/zones.h"
|
||||||
#ifndef HWY_DISABLED_TARGETS
|
#ifndef HWY_DISABLED_TARGETS
|
||||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||||
#endif // HWY_DISABLED_TARGETS
|
#endif // HWY_DISABLED_TARGETS
|
||||||
|
|
@ -55,8 +56,7 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
|
||||||
const float* HWY_RESTRICT q,
|
const float* HWY_RESTRICT q,
|
||||||
const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
|
const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
|
||||||
hwy::Profiler& p, const size_t worker) {
|
hwy::Profiler& p, const size_t worker) {
|
||||||
static const auto zone = p.AddZone("Gen.Attention.QDotK");
|
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenAttentionQDotK));
|
||||||
PROFILER_ZONE3(p, worker, zone);
|
|
||||||
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
|
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
|
||||||
// Slightly faster: no wraparound.
|
// Slightly faster: no wraparound.
|
||||||
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
|
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
|
||||||
|
|
@ -175,7 +175,12 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
|
||||||
const LayerWeightsPtrs& layer,
|
const LayerWeightsPtrs& layer,
|
||||||
AttentionActivations& activations, QBatch& qbatch,
|
AttentionActivations& activations, QBatch& qbatch,
|
||||||
ThreadingContext& ctx) {
|
ThreadingContext& ctx) {
|
||||||
static const auto zone = ctx.profiler.AddZone("Gen.Attention.DotSoftmax.par");
|
static const auto root_zone =
|
||||||
|
ctx.profiler.AddZone("Gen.Attention.DotSoftmaxWeightedSumInclusive",
|
||||||
|
hwy::ProfilerFlags::kInclusive);
|
||||||
|
PROFILER_ZONE3(ctx.profiler, 0, root_zone);
|
||||||
|
const auto zone =
|
||||||
|
GetProfilerZone(Zones::kGenAttentionDotSoftmaxWeightedSumPar);
|
||||||
|
|
||||||
const hwy::Divisor div_qbatch(qbatch.Size());
|
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||||
const LayerConfig& layer_config = layer.layer_config;
|
const LayerConfig& layer_config = layer.layer_config;
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@
|
||||||
|
|
||||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||||
#include "util/threading_context.h"
|
#include "util/threading_context.h"
|
||||||
|
#include "util/zones.h"
|
||||||
#ifndef HWY_DISABLED_TARGETS
|
#ifndef HWY_DISABLED_TARGETS
|
||||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||||
#endif // HWY_DISABLED_TARGETS
|
#endif // HWY_DISABLED_TARGETS
|
||||||
|
|
@ -60,7 +61,7 @@ static constexpr size_t kNFx8HTileSize = 8;
|
||||||
// possible consecutive elements have the same KV.
|
// possible consecutive elements have the same KV.
|
||||||
static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
|
static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
|
||||||
const size_t qbatch_size, ThreadingContext& ctx) {
|
const size_t qbatch_size, ThreadingContext& ctx) {
|
||||||
static const auto zone = ctx.profiler.AddZone("Gen.Attention.TransposeQ");
|
const auto zone = GetProfilerZone(Zones::kFlashAttentionTransposeQ);
|
||||||
// Group floats by the number of floats in a cache line.
|
// Group floats by the number of floats in a cache line.
|
||||||
const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float);
|
const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float);
|
||||||
const size_t num_heads = q.Cols() / q_t.Rows();
|
const size_t num_heads = q.Cols() / q_t.Rows();
|
||||||
|
|
@ -95,8 +96,8 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
|
||||||
const LayerWeightsPtrs& layer,
|
const LayerWeightsPtrs& layer,
|
||||||
const AttentionActivations& activations,
|
const AttentionActivations& activations,
|
||||||
ThreadingContext& ctx) {
|
ThreadingContext& ctx) {
|
||||||
static const auto zone =
|
const auto zone =
|
||||||
ctx.profiler.AddZone("Gen.Attention.RMSNormAndPositionalEncoding");
|
GetProfilerZone(Zones::kFlashAttentionRmsNormAndPositionalEncoding);
|
||||||
const float query_scale = activations.query_scale;
|
const float query_scale = activations.query_scale;
|
||||||
const hwy::Divisor div_qbatch(qbatch.Size());
|
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||||
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
|
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
|
||||||
|
|
@ -158,8 +159,8 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
|
||||||
const AttentionActivations& activations,
|
const AttentionActivations& activations,
|
||||||
float* HWY_RESTRICT att_out, hwy::Profiler& p,
|
float* HWY_RESTRICT att_out, hwy::Profiler& p,
|
||||||
const size_t worker) {
|
const size_t worker) {
|
||||||
static const auto zone = p.AddZone("Gen.Attention.SingleFlashAttention");
|
PROFILER_ZONE3(p, worker,
|
||||||
PROFILER_ZONE3(p, worker, zone);
|
GetProfilerZone(Zones::kFlashAttentionSingleFlashAttention));
|
||||||
const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
|
const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
|
||||||
float m = Dot(q, k.Row(pos_mod), k.Cols());
|
float m = Dot(q, k.Row(pos_mod), k.Cols());
|
||||||
if (float cap = activations.config.att_cap; cap > 0.0f) {
|
if (float cap = activations.config.att_cap; cap > 0.0f) {
|
||||||
|
|
@ -276,8 +277,8 @@ void TileFlashAttention(
|
||||||
const LayerWeightsPtrs& layer, const AttentionActivations& activations,
|
const LayerWeightsPtrs& layer, const AttentionActivations& activations,
|
||||||
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
|
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
|
||||||
hwy::Profiler& p, const size_t worker) {
|
hwy::Profiler& p, const size_t worker) {
|
||||||
static const auto zone = p.AddZone("Gen.Attention.TileFlashAttention");
|
PROFILER_ZONE3(p, worker,
|
||||||
PROFILER_ZONE3(p, worker, zone);
|
GetProfilerZone(Zones::kFlashAttentionTileFlashAttention));
|
||||||
constexpr int kHTileSize = kNFx8HTileSize;
|
constexpr int kHTileSize = kNFx8HTileSize;
|
||||||
using DF = hn::ScalableTag<float>;
|
using DF = hn::ScalableTag<float>;
|
||||||
const DF df;
|
const DF df;
|
||||||
|
|
@ -430,8 +431,8 @@ void TileFlashAttention4(
|
||||||
const LayerWeightsPtrs& layer, const AttentionActivations& activations,
|
const LayerWeightsPtrs& layer, const AttentionActivations& activations,
|
||||||
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
|
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
|
||||||
hwy::Profiler& p, const size_t worker) {
|
hwy::Profiler& p, const size_t worker) {
|
||||||
static const auto zone = p.AddZone("Gen.Attention.TileFlashAttention4");
|
PROFILER_ZONE3(p, worker,
|
||||||
PROFILER_ZONE3(p, worker, zone);
|
GetProfilerZone(Zones::kFlashAttentionTileFlashAttention4));
|
||||||
using DF = hn::ScalableTag<float>;
|
using DF = hn::ScalableTag<float>;
|
||||||
const DF df;
|
const DF df;
|
||||||
using VF = hn::Vec<DF>;
|
using VF = hn::Vec<DF>;
|
||||||
|
|
@ -524,6 +525,21 @@ static size_t RoundToSuitablePowerOf2(size_t n) {
|
||||||
return 32;
|
return 32;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The vertical tile size is determined by the ability to use tiling and the
|
||||||
|
// target_parallelism. In practice the possible tile sizes in order of
|
||||||
|
// preference for efficiency are kNF, 4, 1, where kNF is likely to be 4 8 or
|
||||||
|
// 16. The final tile size is chosen to be the largest possible that allows
|
||||||
|
// for target_parallelism parallel tasks.
|
||||||
|
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,
|
||||||
|
size_t total_tasks, size_t target_parallelism) {
|
||||||
|
const size_t kMaxEqualK =
|
||||||
|
RoundToSuitablePowerOf2(num_head_groups * num_tokens);
|
||||||
|
const size_t kMinTileSize = (total_tasks / 4 >= target_parallelism) ? 4 : 1;
|
||||||
|
return (kNF <= kMaxEqualK && total_tasks / kNF >= target_parallelism)
|
||||||
|
? kNF
|
||||||
|
: std::min(kMinTileSize, kMaxEqualK);
|
||||||
|
}
|
||||||
|
|
||||||
// The nominal aim of attention is to combine 3 inputs Q[L,D], K[L,D], V[L,D]
|
// The nominal aim of attention is to combine 3 inputs Q[L,D], K[L,D], V[L,D]
|
||||||
// into a single output O[L,D].
|
// into a single output O[L,D].
|
||||||
// Conventional attention first computes A[L,L] = Q . KT
|
// Conventional attention first computes A[L,L] = Q . KT
|
||||||
|
|
@ -582,7 +598,10 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
||||||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||||
AttentionActivations& activations, QBatch& qbatch,
|
AttentionActivations& activations, QBatch& qbatch,
|
||||||
ThreadingContext& ctx) {
|
ThreadingContext& ctx) {
|
||||||
static const auto zone = ctx.profiler.AddZone("Gen.Attention.FlashAttention");
|
static const auto root_zone = ctx.profiler.AddZone(
|
||||||
|
"FlashAttention.Inclusive", hwy::ProfilerFlags::kInclusive);
|
||||||
|
PROFILER_ZONE3(ctx.profiler, 0, root_zone);
|
||||||
|
const auto zone = GetProfilerZone(Zones::kFlashAttentionFlashAttention);
|
||||||
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx,
|
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx,
|
||||||
layer, activations, ctx);
|
layer, activations, ctx);
|
||||||
const hwy::Divisor div_qbatch(qbatch.Size());
|
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||||
|
|
@ -603,17 +622,8 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
||||||
const size_t kNF = hn::Lanes(df);
|
const size_t kNF = hn::Lanes(df);
|
||||||
constexpr size_t kMaxNF = hn::MaxLanes(df);
|
constexpr size_t kMaxNF = hn::MaxLanes(df);
|
||||||
HWY_DASSERT(kNF <= kMaxNF);
|
HWY_DASSERT(kNF <= kMaxNF);
|
||||||
// The vertical tile size is determined by the ability to use tiling and the
|
const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, num_tokens,
|
||||||
// target_parallelism. In practice the possible tile sizes in order of
|
total_tasks, target_parallelism);
|
||||||
// preference for efficiency are kNF, 4, 1, where kNF is likely to be 4 8 or
|
|
||||||
// 16. The final tile size is chosen to be the largest possible that allows
|
|
||||||
// for target_parallelism parallel tasks.
|
|
||||||
const size_t kMaxEqualK = RoundToSuitablePowerOf2(kHeadGroups * num_tokens);
|
|
||||||
const size_t kMinTileSize = (total_tasks / 4 >= target_parallelism) ? 4 : 1;
|
|
||||||
const size_t kVTileSize =
|
|
||||||
(kNF <= kMaxEqualK && total_tasks / kNF >= target_parallelism)
|
|
||||||
? kNF
|
|
||||||
: std::min(kMinTileSize, kMaxEqualK);
|
|
||||||
// Only transpose Q if we are using tiling.
|
// Only transpose Q if we are using tiling.
|
||||||
if (kVTileSize == kNF) {
|
if (kVTileSize == kNF) {
|
||||||
size_t max_last = 0, min_start = std::numeric_limits<size_t>::max();
|
size_t max_last = 0, min_start = std::numeric_limits<size_t>::max();
|
||||||
|
|
|
||||||
|
|
@ -42,6 +42,9 @@ namespace gcpp {
|
||||||
float* HWY_RESTRICT att_out, hwy::Profiler& p, \
|
float* HWY_RESTRICT att_out, hwy::Profiler& p, \
|
||||||
size_t worker); \
|
size_t worker); \
|
||||||
\
|
\
|
||||||
|
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
|
||||||
|
size_t total_tasks, size_t target_parallelism); \
|
||||||
|
\
|
||||||
void FlashAttention(size_t num_tokens, size_t target_parallelism, \
|
void FlashAttention(size_t num_tokens, size_t target_parallelism, \
|
||||||
size_t layer_idx, const LayerWeightsPtrs& layer, \
|
size_t layer_idx, const LayerWeightsPtrs& layer, \
|
||||||
AttentionActivations& activations, QBatch& qbatch, \
|
AttentionActivations& activations, QBatch& qbatch, \
|
||||||
|
|
|
||||||
|
|
@ -101,7 +101,6 @@ void AssertClose(const MatPtrT<float>& a, const MatPtrT<float>& b) {
|
||||||
void TestFlashAttention(size_t target_parallelism) {
|
void TestFlashAttention(size_t target_parallelism) {
|
||||||
ThreadingArgs threading_args;
|
ThreadingArgs threading_args;
|
||||||
ThreadingContext ctx(threading_args);
|
ThreadingContext ctx(threading_args);
|
||||||
// hwy::ThreadPool& pool = ctx.pools.Pool();
|
|
||||||
constexpr size_t kOuter = 1024;
|
constexpr size_t kOuter = 1024;
|
||||||
constexpr size_t kInner = 256;
|
constexpr size_t kInner = 256;
|
||||||
ModelConfig config(Model::GEMMA2_2B, Type::kF32, PromptWrapping::GEMMA_PT);
|
ModelConfig config(Model::GEMMA2_2B, Type::kF32, PromptWrapping::GEMMA_PT);
|
||||||
|
|
@ -150,9 +149,19 @@ void TestFlashAttention(size_t target_parallelism) {
|
||||||
// Copy the output to saved_att to allow for comparison.
|
// Copy the output to saved_att to allow for comparison.
|
||||||
auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator);
|
auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator);
|
||||||
SetMat(1, attention.q);
|
SetMat(1, attention.q);
|
||||||
|
using DF = hn::ScalableTag<float>;
|
||||||
|
const DF df;
|
||||||
|
const size_t kNF = hn::Lanes(df);
|
||||||
|
const size_t total_tasks =
|
||||||
|
tokens.size() * div_qbatch.GetDivisor() * layer_config.heads;
|
||||||
|
const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, tokens.size(),
|
||||||
|
total_tasks, target_parallelism);
|
||||||
|
printf("FlashAttention: target_parallelism=%zu, kNF=%zu, kVTileSize=%zu\n",
|
||||||
|
target_parallelism, kNF, kVTileSize);
|
||||||
FlashAttention(tokens.size(), target_parallelism, 0, layers, attention,
|
FlashAttention(tokens.size(), target_parallelism, 0, layers, attention,
|
||||||
qbatch, ctx);
|
qbatch, ctx);
|
||||||
AssertClose(attention.att_out, *saved_att);
|
AssertClose(attention.att_out, *saved_att);
|
||||||
|
ctx.profiler.PrintResults();
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestAttention() {
|
void TestAttention() {
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@
|
||||||
#include "ops/matmul.h"
|
#include "ops/matmul.h"
|
||||||
#include "util/mat.h"
|
#include "util/mat.h"
|
||||||
#include "util/threading.h"
|
#include "util/threading.h"
|
||||||
|
#include "util/zones.h"
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
|
|
||||||
// Include guard (still compiled once per target)
|
// Include guard (still compiled once per target)
|
||||||
|
|
@ -48,8 +49,7 @@ template <typename T1, typename T2>
|
||||||
void Activation(ActivationType activation, T1* HWY_RESTRICT c1,
|
void Activation(ActivationType activation, T1* HWY_RESTRICT c1,
|
||||||
const T2* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p,
|
const T2* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p,
|
||||||
const size_t worker) {
|
const size_t worker) {
|
||||||
static const auto zone = p.AddZone("Gen.Activation");
|
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenActivation));
|
||||||
PROFILER_ZONE3(p, worker, zone);
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using DF = hn::ScalableTag<float>;
|
using DF = hn::ScalableTag<float>;
|
||||||
using VF = hn::Vec<DF>;
|
using VF = hn::Vec<DF>;
|
||||||
|
|
@ -88,8 +88,7 @@ static inline void Activation(ActivationType activation, const RowPtrsBF C1,
|
||||||
const IndexRange range_r,
|
const IndexRange range_r,
|
||||||
const IndexRange range_c, const StridedViewBF C2,
|
const IndexRange range_c, const StridedViewBF C2,
|
||||||
hwy::Profiler& p, const size_t worker) {
|
hwy::Profiler& p, const size_t worker) {
|
||||||
static const auto zone = p.AddZone("Gen.ActivationFused");
|
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenActivationFused));
|
||||||
PROFILER_ZONE3(p, worker, zone);
|
|
||||||
|
|
||||||
const size_t cols = range_c.Num();
|
const size_t cols = range_c.Num();
|
||||||
HWY_DASSERT(C2.Cols() == cols);
|
HWY_DASSERT(C2.Cols() == cols);
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/gemma.h"
|
||||||
|
|
||||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||||
|
#include "util/zones.h"
|
||||||
#ifndef HWY_DISABLED_TARGETS
|
#ifndef HWY_DISABLED_TARGETS
|
||||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||||
#endif // HWY_DISABLED_TARGETS
|
#endif // HWY_DISABLED_TARGETS
|
||||||
|
|
@ -466,14 +467,12 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config,
|
||||||
// If user provided a sample_func, use it.
|
// If user provided a sample_func, use it.
|
||||||
if (runtime_config.sample_func) return runtime_config.sample_func;
|
if (runtime_config.sample_func) return runtime_config.sample_func;
|
||||||
|
|
||||||
static const auto zone_top1 = ctx.profiler.AddZone("Gen.Sample Top1");
|
|
||||||
static const auto zone_topK = ctx.profiler.AddZone("Gen.Sample general");
|
|
||||||
|
|
||||||
// Fast path for top-1 with no accept_token.
|
// Fast path for top-1 with no accept_token.
|
||||||
if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
|
if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
|
||||||
return [&](size_t /*qi*/, size_t /*pos*/, Logits logits, size_t worker)
|
return [&](size_t /*qi*/, size_t /*pos*/, Logits logits, size_t worker)
|
||||||
HWY_ATTR -> TokenAndProb {
|
HWY_ATTR -> TokenAndProb {
|
||||||
PROFILER_ZONE3(ctx.profiler, worker, zone_top1);
|
PROFILER_ZONE3(ctx.profiler, worker,
|
||||||
|
GetProfilerZone(Zones::kGenSampleTop1));
|
||||||
return Top1OfSoftmax(logits);
|
return Top1OfSoftmax(logits);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
@ -481,7 +480,8 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config,
|
||||||
// General case: Softmax with top-k sampling.
|
// General case: Softmax with top-k sampling.
|
||||||
return [&](size_t qi, size_t pos, Logits logits,
|
return [&](size_t qi, size_t pos, Logits logits,
|
||||||
size_t worker) HWY_ATTR -> TokenAndProb {
|
size_t worker) HWY_ATTR -> TokenAndProb {
|
||||||
PROFILER_ZONE3(ctx.profiler, worker, zone_topK);
|
PROFILER_ZONE3(ctx.profiler, worker,
|
||||||
|
GetProfilerZone(Zones::kGenSampleTopK));
|
||||||
// We want a different sequence for each batch element and position.
|
// We want a different sequence for each batch element and position.
|
||||||
const uint64_t stream = (static_cast<uint64_t>(qi) << 32) | pos;
|
const uint64_t stream = (static_cast<uint64_t>(qi) << 32) | pos;
|
||||||
RngStream gen(engine, stream);
|
RngStream gen(engine, stream);
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,7 @@
|
||||||
#include "io/blob_store.h"
|
#include "io/blob_store.h"
|
||||||
#include "util/mat.h"
|
#include "util/mat.h"
|
||||||
#include "util/threading_context.h"
|
#include "util/threading_context.h"
|
||||||
|
#include "util/zones.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
|
|
@ -379,8 +380,7 @@ static void DecompressToBF16(MatPtr& mat,
|
||||||
|
|
||||||
static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
|
static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
|
||||||
const BlobReader& reader, ThreadingContext& ctx) {
|
const BlobReader& reader, ThreadingContext& ctx) {
|
||||||
static const auto zone =
|
const auto zone = GetProfilerZone(Zones::kStartupWeightsReadAllToBF16);
|
||||||
ctx.profiler.AddZone("Startup.Weights.ReadAllToBF16");
|
|
||||||
// Especially TSAN is slow enough to warrant hierarchical parallelism.
|
// Especially TSAN is slow enough to warrant hierarchical parallelism.
|
||||||
const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD
|
const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD
|
||||||
? ParallelismStrategy::kHierarchical
|
? ParallelismStrategy::kHierarchical
|
||||||
|
|
@ -463,7 +463,7 @@ static std::vector<IOBatch> MakeBatches(
|
||||||
static void ReadBatches(const BlobReader& reader,
|
static void ReadBatches(const BlobReader& reader,
|
||||||
const std::vector<IOBatch>& batches,
|
const std::vector<IOBatch>& batches,
|
||||||
ThreadingContext& ctx) {
|
ThreadingContext& ctx) {
|
||||||
static const auto zone = ctx.profiler.AddZone("Startup.Weights.ReadBatches");
|
const auto zone = GetProfilerZone(Zones::kStartupWeightsReadBatches);
|
||||||
// >5x speedup from parallel reads when cached.
|
// >5x speedup from parallel reads when cached.
|
||||||
ParallelFor(ParallelismStrategy::kHierarchical,
|
ParallelFor(ParallelismStrategy::kHierarchical,
|
||||||
batches.size(), ctx, /*cluster_idx=*/0,
|
batches.size(), ctx, /*cluster_idx=*/0,
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@
|
||||||
#include "util/basics.h"
|
#include "util/basics.h"
|
||||||
#include "util/mat.h"
|
#include "util/mat.h"
|
||||||
#include "util/threading_context.h"
|
#include "util/threading_context.h"
|
||||||
|
#include "util/zones.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
|
|
@ -290,7 +291,7 @@ class MMDecompress {
|
||||||
const hn::ScalableTag<BF16> dbf;
|
const hn::ScalableTag<BF16> dbf;
|
||||||
const size_t NBF = hn::Lanes(dbf);
|
const size_t NBF = hn::Lanes(dbf);
|
||||||
|
|
||||||
static const auto zone = env.ctx.profiler.AddZone("MM.DecompressA");
|
const auto zone = GetProfilerZone(Zones::kMMDecompressA);
|
||||||
|
|
||||||
const auto do_range =
|
const auto do_range =
|
||||||
[&](const IndexRange& range_M, const IndexRange& range_K, size_t worker)
|
[&](const IndexRange& range_M, const IndexRange& range_K, size_t worker)
|
||||||
|
|
@ -878,9 +879,9 @@ class MMLoops {
|
||||||
static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT<TB>& B,
|
static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT<TB>& B,
|
||||||
const MatPtrT<TB>* B2, RowPtrs<TC> C,
|
const MatPtrT<TB>* B2, RowPtrs<TC> C,
|
||||||
const MMArgs& args) {
|
const MMArgs& args) {
|
||||||
static const auto zone = args.env.ctx.profiler.AddZone("MM.Dispatch");
|
|
||||||
PROFILER_ZONE3(args.env.ctx.profiler,
|
PROFILER_ZONE3(args.env.ctx.profiler,
|
||||||
args.env.ctx.Worker(args.options.cluster_idx), zone);
|
args.env.ctx.Worker(args.options.cluster_idx),
|
||||||
|
GetProfilerZone(Zones::kMMDispatch));
|
||||||
|
|
||||||
DispatchParallelism(
|
DispatchParallelism(
|
||||||
args.options.parallelism, [&](const auto& parallel) HWY_ATTR {
|
args.options.parallelism, [&](const auto& parallel) HWY_ATTR {
|
||||||
|
|
@ -903,7 +904,7 @@ class MMLoops {
|
||||||
const StridedViewBF A, const MatPtrT<TB>& B,
|
const StridedViewBF A, const MatPtrT<TB>& B,
|
||||||
const MatPtrT<TB>* B2, RowPtrs<TC> C,
|
const MatPtrT<TB>* B2, RowPtrs<TC> C,
|
||||||
const MMArgs& args) {
|
const MMArgs& args) {
|
||||||
static const auto zone = args.env.ctx.profiler.AddZone("MM.NT");
|
const auto zone = GetProfilerZone(Zones::kMMNT);
|
||||||
HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
|
HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
|
||||||
HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
|
HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
|
||||||
const IndexRange& range_mc = args.ranges_mc.Range(0);
|
const IndexRange& range_mc = args.ranges_mc.Range(0);
|
||||||
|
|
@ -939,7 +940,7 @@ class MMLoops {
|
||||||
const StridedViewBF A, const MatPtrT<TB>& B,
|
const StridedViewBF A, const MatPtrT<TB>& B,
|
||||||
const MatPtrT<TB>* B2, RowPtrs<TC> C,
|
const MatPtrT<TB>* B2, RowPtrs<TC> C,
|
||||||
const MMArgs& args) {
|
const MMArgs& args) {
|
||||||
static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_K");
|
const auto zone = GetProfilerZone(Zones::kMMNT_K);
|
||||||
HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
|
HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
|
||||||
const IndexRange& range_mc = args.ranges_mc.Range(0);
|
const IndexRange& range_mc = args.ranges_mc.Range(0);
|
||||||
|
|
||||||
|
|
@ -975,7 +976,7 @@ class MMLoops {
|
||||||
const StridedViewBF A, const MatPtrT<TB>& B,
|
const StridedViewBF A, const MatPtrT<TB>& B,
|
||||||
const MatPtrT<TB>* B2, RowPtrs<TC> C,
|
const MatPtrT<TB>* B2, RowPtrs<TC> C,
|
||||||
const MMArgs& args) {
|
const MMArgs& args) {
|
||||||
static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT");
|
const auto zone = GetProfilerZone(Zones::kMMNT_MT);
|
||||||
HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
|
HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
|
||||||
const IndexRange& range_kc = args.ranges_kc.Range(0);
|
const IndexRange& range_kc = args.ranges_kc.Range(0);
|
||||||
|
|
||||||
|
|
@ -1009,7 +1010,7 @@ class MMLoops {
|
||||||
const StridedViewBF A, const MatPtrT<TB>& B,
|
const StridedViewBF A, const MatPtrT<TB>& B,
|
||||||
const MatPtrT<TB>* B2, RowPtrs<TC> C,
|
const MatPtrT<TB>* B2, RowPtrs<TC> C,
|
||||||
const MMArgs& args) {
|
const MMArgs& args) {
|
||||||
static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT_K");
|
const auto zone = GetProfilerZone(Zones::kMMNT_MT_K);
|
||||||
|
|
||||||
parallel.ForRangesMC_NC(
|
parallel.ForRangesMC_NC(
|
||||||
args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx,
|
args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx,
|
||||||
|
|
@ -1060,10 +1061,10 @@ template <typename TA, typename TB, typename TC>
|
||||||
HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
||||||
const float* HWY_RESTRICT add, MatMulEnv& env,
|
const float* HWY_RESTRICT add, MatMulEnv& env,
|
||||||
MatPtrT<TC>& C, MMOptions options = MMOptions()) {
|
MatPtrT<TC>& C, MMOptions options = MMOptions()) {
|
||||||
static const auto zone = env.ctx.profiler.AddZone("MM.MatMul");
|
|
||||||
const size_t cluster_idx = options.cluster_idx;
|
const size_t cluster_idx = options.cluster_idx;
|
||||||
HWY_DASSERT(cluster_idx < env.row_ptrs.size());
|
HWY_DASSERT(cluster_idx < env.row_ptrs.size());
|
||||||
PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), zone);
|
PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx),
|
||||||
|
GetProfilerZone(Zones::kMMMatMul));
|
||||||
|
|
||||||
RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]);
|
RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]);
|
||||||
|
|
||||||
|
|
@ -1121,10 +1122,10 @@ template <typename TB>
|
||||||
HWY_NOINLINE MMPerKey* TwoMatMul(const MatPtrT<BF16>& A, const MatPtrT<TB>& B1,
|
HWY_NOINLINE MMPerKey* TwoMatMul(const MatPtrT<BF16>& A, const MatPtrT<TB>& B1,
|
||||||
const MatPtrT<TB>& B2, MatMulEnv& env,
|
const MatPtrT<TB>& B2, MatMulEnv& env,
|
||||||
MatPtrT<BF16>& C, MMOptions options) {
|
MatPtrT<BF16>& C, MMOptions options) {
|
||||||
static const auto zone = env.ctx.profiler.AddZone("MM.TwoMatMul");
|
|
||||||
const size_t cluster_idx = options.cluster_idx;
|
const size_t cluster_idx = options.cluster_idx;
|
||||||
HWY_DASSERT(cluster_idx < env.row_ptrs.size());
|
HWY_DASSERT(cluster_idx < env.row_ptrs.size());
|
||||||
PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), zone);
|
PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx),
|
||||||
|
GetProfilerZone(Zones::kMMTwoMatMul));
|
||||||
|
|
||||||
HWY_DASSERT(options.func != nullptr); // no other way to get access to C2.
|
HWY_DASSERT(options.func != nullptr); // no other way to get access to C2.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,7 @@
|
||||||
#include "util/basics.h" // TokenAndProb, RngStream
|
#include "util/basics.h" // TokenAndProb, RngStream
|
||||||
#include "util/mat.h"
|
#include "util/mat.h"
|
||||||
#include "util/threading_context.h"
|
#include "util/threading_context.h"
|
||||||
|
#include "util/zones.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/bit_set.h"
|
#include "hwy/bit_set.h"
|
||||||
#include "hwy/contrib/sort/order.h"
|
#include "hwy/contrib/sort/order.h"
|
||||||
|
|
@ -206,8 +207,7 @@ namespace detail {
|
||||||
template <typename VT>
|
template <typename VT>
|
||||||
float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
|
float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
|
||||||
const size_t worker) {
|
const size_t worker) {
|
||||||
static const auto zone = p.AddZone("Ops.RMSNormMul");
|
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNormMul));
|
||||||
PROFILER_ZONE3(p, worker, zone);
|
|
||||||
|
|
||||||
const hn::ScalableTag<float> d;
|
const hn::ScalableTag<float> d;
|
||||||
const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault());
|
const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault());
|
||||||
|
|
@ -223,8 +223,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
|
||||||
OT* HWY_RESTRICT out,
|
OT* HWY_RESTRICT out,
|
||||||
const size_t size, hwy::Profiler& p,
|
const size_t size, hwy::Profiler& p,
|
||||||
const size_t worker) {
|
const size_t worker) {
|
||||||
static const auto zone = p.AddZone("Ops.RMSNorm");
|
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNorm));
|
||||||
PROFILER_ZONE3(p, worker, zone);
|
|
||||||
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using DF = hn::ScalableTag<float>;
|
using DF = hn::ScalableTag<float>;
|
||||||
|
|
@ -248,8 +247,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight,
|
||||||
const size_t size,
|
const size_t size,
|
||||||
hwy::Profiler& p,
|
hwy::Profiler& p,
|
||||||
const size_t worker) {
|
const size_t worker) {
|
||||||
static const auto zone = p.AddZone("Ops.RMSNormInplace");
|
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNormInplace));
|
||||||
PROFILER_ZONE3(p, worker, zone);
|
|
||||||
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using DF = hn::ScalableTag<float>;
|
using DF = hn::ScalableTag<float>;
|
||||||
|
|
@ -365,8 +363,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
|
||||||
float* HWY_RESTRICT x, const size_t dim_qkv,
|
float* HWY_RESTRICT x, const size_t dim_qkv,
|
||||||
const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p,
|
const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p,
|
||||||
const size_t worker) {
|
const size_t worker) {
|
||||||
static const auto zone = p.AddZone("Ops.Rope");
|
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRope));
|
||||||
PROFILER_ZONE3(p, worker, zone);
|
|
||||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||||
const size_t half_dim_qkv = dim_qkv / 2;
|
const size_t half_dim_qkv = dim_qkv / 2;
|
||||||
|
|
||||||
|
|
@ -425,8 +422,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
|
||||||
const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
|
const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
|
||||||
const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p,
|
const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p,
|
||||||
const size_t worker) {
|
const size_t worker) {
|
||||||
static const auto zone = p.AddZone("Ops.RopeAndMulBy");
|
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRopeAndMulBy));
|
||||||
PROFILER_ZONE3(p, worker, zone);
|
|
||||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||||
const size_t half_dim_qkv = dim_qkv / 2;
|
const size_t half_dim_qkv = dim_qkv / 2;
|
||||||
|
|
||||||
|
|
@ -488,8 +484,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
|
||||||
const size_t size,
|
const size_t size,
|
||||||
hwy::Profiler& p,
|
hwy::Profiler& p,
|
||||||
const size_t worker) {
|
const size_t worker) {
|
||||||
static const auto zone = p.AddZone("Ops.AddFrom");
|
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsAddFrom));
|
||||||
PROFILER_ZONE3(p, worker, zone);
|
|
||||||
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using DF = hn::ScalableTag<float>;
|
using DF = hn::ScalableTag<float>;
|
||||||
|
|
@ -568,8 +563,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x,
|
||||||
const size_t size,
|
const size_t size,
|
||||||
hwy::Profiler& p,
|
hwy::Profiler& p,
|
||||||
const size_t worker) {
|
const size_t worker) {
|
||||||
static const auto zone = p.AddZone("Ops.MulByConst");
|
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConst));
|
||||||
PROFILER_ZONE3(p, worker, zone);
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using DF = hn::ScalableTag<float>;
|
using DF = hn::ScalableTag<float>;
|
||||||
using VF = hn::Vec<DF>;
|
using VF = hn::Vec<DF>;
|
||||||
|
|
@ -587,8 +581,7 @@ template <typename XT, typename OT>
|
||||||
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
|
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
|
||||||
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
|
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
|
||||||
const size_t size, hwy::Profiler& p, const size_t worker) {
|
const size_t size, hwy::Profiler& p, const size_t worker) {
|
||||||
static const auto zone = p.AddZone("Ops.MulByConstTo");
|
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstTo));
|
||||||
PROFILER_ZONE3(p, worker, zone);
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using DF = hn::ScalableTag<float>;
|
using DF = hn::ScalableTag<float>;
|
||||||
using VF = hn::Vec<DF>;
|
using VF = hn::Vec<DF>;
|
||||||
|
|
@ -606,8 +599,7 @@ template <typename XT, typename OT>
|
||||||
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
||||||
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
|
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
|
||||||
const size_t size, hwy::Profiler& p, const size_t worker) {
|
const size_t size, hwy::Profiler& p, const size_t worker) {
|
||||||
static const auto zone = p.AddZone("Ops.MulByConstAndAdd");
|
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAdd));
|
||||||
PROFILER_ZONE3(p, worker, zone);
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using DF = hn::ScalableTag<float>;
|
using DF = hn::ScalableTag<float>;
|
||||||
using VF = hn::Vec<DF>;
|
using VF = hn::Vec<DF>;
|
||||||
|
|
@ -744,8 +736,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile(
|
||||||
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
|
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
|
||||||
const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
|
const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
|
||||||
hwy::Profiler& p, const size_t worker) {
|
hwy::Profiler& p, const size_t worker) {
|
||||||
static const auto zone = p.AddZone("Ops.MulByConstAndAdd");
|
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddTile));
|
||||||
PROFILER_ZONE3(p, worker, zone);
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
|
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
|
||||||
|
|
||||||
|
|
@ -1007,8 +998,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4(
|
||||||
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
|
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
|
||||||
const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
|
const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
|
||||||
hwy::Profiler& p, const size_t worker) {
|
hwy::Profiler& p, const size_t worker) {
|
||||||
static const auto zone = p.AddZone("Ops.MulByConstAndAddTile4");
|
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddTile4));
|
||||||
PROFILER_ZONE3(p, worker, zone);
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
|
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
|
||||||
|
|
||||||
|
|
@ -1049,8 +1039,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
|
||||||
const size_t pos, float* HWY_RESTRICT out,
|
const size_t pos, float* HWY_RESTRICT out,
|
||||||
const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
|
const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
|
||||||
hwy::Profiler& p, const size_t worker) {
|
hwy::Profiler& p, const size_t worker) {
|
||||||
static const auto zone = p.AddZone("Ops.MulByConstAndAdd");
|
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddVector));
|
||||||
PROFILER_ZONE3(p, worker, zone);
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
|
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
|
||||||
|
|
||||||
|
|
@ -1146,8 +1135,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
|
||||||
static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,
|
static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,
|
||||||
const size_t worker,
|
const size_t worker,
|
||||||
float temperature = 1.0f) {
|
float temperature = 1.0f) {
|
||||||
static const auto zone = p.AddZone("Ops.Softmax");
|
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsSoftmax));
|
||||||
PROFILER_ZONE3(p, worker, zone);
|
|
||||||
HWY_DASSERT(logits.size() != 0);
|
HWY_DASSERT(logits.size() != 0);
|
||||||
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
|
@ -1280,8 +1268,7 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(Logits logits) {
|
||||||
|
|
||||||
static HWY_NOINLINE void LogitsSoftCap(const float cap, Logits logits,
|
static HWY_NOINLINE void LogitsSoftCap(const float cap, Logits logits,
|
||||||
hwy::Profiler& p, const size_t worker) {
|
hwy::Profiler& p, const size_t worker) {
|
||||||
static const auto zone = p.AddZone("Ops.LogitsSoftCap");
|
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsLogitsSoftCap));
|
||||||
PROFILER_ZONE3(p, worker, zone);
|
|
||||||
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using DF = hn::ScalableTag<float>;
|
using DF = hn::ScalableTag<float>;
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include "compression/types.h"
|
#include "compression/types.h"
|
||||||
|
#include "util/zones.h"
|
||||||
#ifndef HWY_DISABLED_TARGETS
|
#ifndef HWY_DISABLED_TARGETS
|
||||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||||
#endif // HWY_DISABLED_TARGETS
|
#endif // HWY_DISABLED_TARGETS
|
||||||
|
|
@ -132,6 +133,7 @@ class TestAddFrom {
|
||||||
}
|
}
|
||||||
|
|
||||||
SimpleAddFrom(o, e, count);
|
SimpleAddFrom(o, e, count);
|
||||||
|
InitProfilerZones(hwy::Profiler::Get());
|
||||||
AddFrom(o, x, count, hwy::Profiler::Get(), /*worker=*/0);
|
AddFrom(o, x, count, hwy::Profiler::Get(), /*worker=*/0);
|
||||||
|
|
||||||
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
||||||
|
|
@ -180,6 +182,7 @@ class TestMulByConstAndAdd {
|
||||||
T constant = Random<T>(rng);
|
T constant = Random<T>(rng);
|
||||||
|
|
||||||
SimpleMulByConstAndAdd(constant, o, e, count);
|
SimpleMulByConstAndAdd(constant, o, e, count);
|
||||||
|
InitProfilerZones(hwy::Profiler::Get());
|
||||||
MulByConstAndAdd(constant, o, x, count, hwy::Profiler::Get(), /*worker=*/0);
|
MulByConstAndAdd(constant, o, x, count, hwy::Profiler::Get(), /*worker=*/0);
|
||||||
|
|
||||||
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
||||||
|
|
@ -228,6 +231,7 @@ class TestMulByConst {
|
||||||
T constant = Random<T>(rng);
|
T constant = Random<T>(rng);
|
||||||
|
|
||||||
SimpleMulByConst(constant, e, count);
|
SimpleMulByConst(constant, e, count);
|
||||||
|
InitProfilerZones(hwy::Profiler::Get());
|
||||||
MulByConst(constant, x, count, hwy::Profiler::Get(), /*worker=*/0);
|
MulByConst(constant, x, count, hwy::Profiler::Get(), /*worker=*/0);
|
||||||
|
|
||||||
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
||||||
|
|
@ -274,6 +278,7 @@ struct TestMulByConstTo {
|
||||||
hwy::ConvertScalarTo<float>(constant));
|
hwy::ConvertScalarTo<float>(constant));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
InitProfilerZones(hwy::Profiler::Get());
|
||||||
MulByConstTo(constant, x, actual, count, hwy::Profiler::Get(),
|
MulByConstTo(constant, x, actual, count, hwy::Profiler::Get(),
|
||||||
/*worker=*/0);
|
/*worker=*/0);
|
||||||
|
|
||||||
|
|
@ -310,6 +315,7 @@ class TestSoftmax {
|
||||||
}
|
}
|
||||||
|
|
||||||
SimpleSoftmax(e, count);
|
SimpleSoftmax(e, count);
|
||||||
|
InitProfilerZones(hwy::Profiler::Get());
|
||||||
Softmax(Logits(x, count), hwy::Profiler::Get(), /*worker=*/0);
|
Softmax(Logits(x, count), hwy::Profiler::Get(), /*worker=*/0);
|
||||||
|
|
||||||
T sum = 0.0f;
|
T sum = 0.0f;
|
||||||
|
|
@ -437,6 +443,7 @@ void TestRopeAndMulBy() {
|
||||||
ThreadingArgs threading_args;
|
ThreadingArgs threading_args;
|
||||||
ThreadingContext ctx(threading_args);
|
ThreadingContext ctx(threading_args);
|
||||||
hwy::Profiler& p = ctx.profiler;
|
hwy::Profiler& p = ctx.profiler;
|
||||||
|
InitProfilerZones(p);
|
||||||
const size_t worker = 0;
|
const size_t worker = 0;
|
||||||
|
|
||||||
const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
|
const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
|
||||||
|
|
@ -551,6 +558,7 @@ struct TestRMSNorm {
|
||||||
}
|
}
|
||||||
|
|
||||||
ScalarRMSNorm(vec, weight, expected, kSize);
|
ScalarRMSNorm(vec, weight, expected, kSize);
|
||||||
|
InitProfilerZones(hwy::Profiler::Get());
|
||||||
RMSNorm(vec, weight, actual, kSize, hwy::Profiler::Get(), /*worker=*/0);
|
RMSNorm(vec, weight, actual, kSize, hwy::Profiler::Get(), /*worker=*/0);
|
||||||
|
|
||||||
for (size_t i = 0; i < kSize; i++) {
|
for (size_t i = 0; i < kSize; i++) {
|
||||||
|
|
@ -585,6 +593,7 @@ struct TestRMSNormInplace {
|
||||||
}
|
}
|
||||||
|
|
||||||
ScalarRMSNorm(expected, weight, expected, kSize);
|
ScalarRMSNorm(expected, weight, expected, kSize);
|
||||||
|
InitProfilerZones(hwy::Profiler::Get());
|
||||||
RMSNormInplace(weight, actual, kSize, hwy::Profiler::Get(),
|
RMSNormInplace(weight, actual, kSize, hwy::Profiler::Get(),
|
||||||
/*worker=*/0);
|
/*worker=*/0);
|
||||||
|
|
||||||
|
|
@ -707,6 +716,7 @@ void TestAllLayerNorm() {
|
||||||
|
|
||||||
void TestSampleTopK() {
|
void TestSampleTopK() {
|
||||||
hwy::Profiler& p = hwy::Profiler::Get();
|
hwy::Profiler& p = hwy::Profiler::Get();
|
||||||
|
InitProfilerZones(p);
|
||||||
const size_t worker = 0;
|
const size_t worker = 0;
|
||||||
const size_t kSize = 52;
|
const size_t kSize = 52;
|
||||||
std::vector<float> logits_vec(kSize);
|
std::vector<float> logits_vec(kSize);
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,70 @@
|
||||||
|
#include "util/zones.h"
|
||||||
|
|
||||||
|
#include "hwy/profiler.h"
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
#if PROFILER_ENABLED
// Number of real zones; `Zones::kNumZones` is the trailing sentinel of the
// enum, so its numeric value equals the zone count.
static constexpr size_t kNumZones = static_cast<size_t>(Zones::kNumZones);

// Names registered with the profiler, indexed by the numeric value of the
// matching `Zones` enumerator. The array bound is deliberately deduced from
// the initializer list: with an explicit `[kNumZones]` bound, a missing entry
// would be silently padded with nullptr instead of failing the build.
static const char* const kProfilerZoneNames[] = {
    // Keep in sync with Zones enum.
    "Ops.RMSNormMul",
    "Ops.RMSNorm",
    "Ops.RMSNormInplace",
    "Ops.Rope",
    "Ops.RopeAndMulBy",
    "Ops.AddFrom",
    "Ops.MulByConst",
    "Ops.MulByConstTo",
    "Ops.MulByConstAndAdd",
    "Ops.MulByConstAndAddTile",
    "Ops.MulByConstAndAddTile4",
    "Ops.MulByConstAndAddVector",
    "Ops.Softmax",
    "Ops.LogitsSoftCap",
    "FlashAttention.TransposeQ",
    "FlashAttention.RMSNormAndPositionalEncoding",
    "FlashAttention.SingleFlashAttention",
    "FlashAttention.TileFlashAttention",
    "FlashAttention.TileFlashAttention4",
    "FlashAttention.FlashAttention",
    "Gen.Activation",
    "Gen.ActivationFused",
    "Gen.SampleTop1",
    "Gen.SampleTopK",
    "Gen.Attention.QDotK",
    "Gen.Attention.DotSoftmaxWeightedSum.par",
    "Startup.Weights.ReadAllToBF16",
    "Startup.Weights.ReadBatches",
    "MM.Dispatch",
    "MM.MatMul",
    "MM.TwoMatMul",
    "MM.DecompressA",
    "MM.NT",
    "MM.NT_K",
    "MM.NT_MT",
    "MM.NT_MT_K",
};

// Fails the build as soon as the enum and the name table drift apart.
static_assert(sizeof(kProfilerZoneNames) / sizeof(kProfilerZoneNames[0]) ==
                  kNumZones,
              "kProfilerZoneNames must have exactly one entry per Zones value");

// Handles returned by `Profiler::AddZone`, filled in by InitProfilerZones and
// read by GetProfilerZone. Same indexing as kProfilerZoneNames.
static hwy::profiler::ZoneHandle profiler_zone_handles[kNumZones];
#endif
|
||||||
|
|
||||||
|
// Registers every name in kProfilerZoneNames with `profiler` and stores the
// resulting handle at the same index in profiler_zone_handles. Does nothing
// when the profiler is compiled out (PROFILER_ENABLED == 0).
void InitProfilerZones(hwy::Profiler& profiler) {
#if PROFILER_ENABLED
  // Walk the name table in order; `slot` tracks the matching handle index.
  size_t slot = 0;
  for (const char* zone_name : kProfilerZoneNames) {
    profiler_zone_handles[slot] = profiler.AddZone(zone_name);
    ++slot;
  }
#endif
}
|
||||||
|
|
||||||
|
// Looks up the handle stored for `zone` by InitProfilerZones. When the
// profiler is compiled out, returns a default-constructed handle instead.
hwy::profiler::ZoneHandle GetProfilerZone(Zones zone) {
#if !PROFILER_ENABLED
  return hwy::profiler::ZoneHandle();
#else
  // The enumerator's numeric value is the index into the handle table.
  const size_t index = static_cast<size_t>(zone);
  return profiler_zone_handles[index];
#endif
}
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_
|
||||||
|
#define THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_
|
||||||
|
|
||||||
|
#include "hwy/profiler.h"
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
// Identifiers for all statically known profiler zones. The numeric value of
// each enumerator is used as an array index by GetProfilerZone, so values are
// consecutive starting at zero and must stay in the same order as the name
// table in util/zones.cc.
enum class Zones {
  // Zones named "Ops.*".
  kOpsRmsNormMul = 0,
  kOpsRmsNorm,
  kOpsRmsNormInplace,
  kOpsRope,
  kOpsRopeAndMulBy,
  kOpsAddFrom,
  kOpsMulByConst,
  kOpsMulByConstTo,
  kOpsMulByConstAndAdd,
  kOpsMulByConstAndAddTile,
  kOpsMulByConstAndAddTile4,
  kOpsMulByConstAndAddVector,
  kOpsSoftmax,
  kOpsLogitsSoftCap,

  // Zones named "FlashAttention.*".
  kFlashAttentionTransposeQ,
  kFlashAttentionRmsNormAndPositionalEncoding,
  kFlashAttentionSingleFlashAttention,
  kFlashAttentionTileFlashAttention,
  kFlashAttentionTileFlashAttention4,
  kFlashAttentionFlashAttention,

  // Zones named "Gen.*".
  kGenActivation,
  kGenActivationFused,
  kGenSampleTop1,
  kGenSampleTopK,
  kGenAttentionQDotK,
  kGenAttentionDotSoftmaxWeightedSumPar,

  // Zones named "Startup.Weights.*".
  kStartupWeightsReadAllToBF16,
  kStartupWeightsReadBatches,

  // Zones named "MM.*".
  kMMDispatch,
  kMMMatMul,
  kMMTwoMatMul,
  kMMDecompressA,
  kMMNT,
  kMMNT_K,
  kMMNT_MT,
  kMMNT_MT_K,

  // Sentinel: equals the number of zones above. Must remain last.
  kNumZones
};
|
||||||
|
|
||||||
|
// Initializes the profiler zones. Must be called before any other profiler
|
||||||
|
// functions: GetProfilerZone returns unregistered handles until this has run.
|
||||||
|
void InitProfilerZones(hwy::Profiler& profiler);
|
||||||
|
|
||||||
|
// Returns the zone handle for the given zone enum value.
|
||||||
|
hwy::profiler::ZoneHandle GetProfilerZone(Zones zone);
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_
|
||||||
Loading…
Reference in New Issue