Rewrote flash attention to use BF16, transpose k and v, rewrote the task distribution, increase parallelism on decode, and use double the registers for the core of flash attention.

PiperOrigin-RevId: 868146247
This commit is contained in:
Ray Smith 2026-02-10 07:55:17 -08:00 committed by Copybara-Service
parent 7e5310b908
commit 7b55d41f46
28 changed files with 6662 additions and 1194 deletions

View File

@ -547,6 +547,7 @@ cc_library(
deps = [
":basics",
":configs",
":flash_structs",
":gemma_args",
":kv_cache",
":mat",
@ -594,6 +595,11 @@ cc_test(
INTERNAL_DEPS = []
cc_library(
name = "flash_structs",
hdrs = ["gemma/flash_structs.h"],
)
cc_library(
name = "attention",
srcs = [
@ -603,7 +609,6 @@ cc_library(
hdrs = [
"gemma/attention.h",
"gemma/flash_attention.h",
"gemma/flash_structs.h",
],
textual_hdrs = [
"gemma/gemma-inl.h",
@ -612,6 +617,7 @@ cc_library(
":activations",
":basics",
":configs",
":flash_structs",
":kv_cache",
":mat",
":matmul",
@ -822,6 +828,38 @@ cc_test(
],
)
cc_test(
name = "wheat_from_chaff_test",
srcs = ["evals/wheat_from_chaff_test.cc"],
data = [
"evals/testdata/google/big_bang_theory.txt",
"evals/testdata/google/black_hole.txt",
"evals/testdata/google/general_relativity.txt",
"evals/testdata/google/qed.txt",
"evals/testdata/holiday_story.txt",
"evals/testdata/quark_1.txt",
"evals/testdata/quark_2.txt",
"evals/testdata/special_relativity.txt",
"evals/testdata/standard_model.txt",
],
linkstatic = True,
# Requires model files
tags = [
"local",
"manual",
"no_tap",
],
deps = [
":benchmark_helper",
":configs",
":gemma_lib",
"@googletest//:gtest_main", # buildcleaner: keep
"//io",
"@highway//:abort_header_only",
"@highway//:hwy_test_util",
],
)
cc_binary(
name = "gemma",
srcs = ["gemma/run.cc"],

View File

@ -150,7 +150,11 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
QueryResult GemmaEnv::QueryModel(const std::string& input) {
const std::vector<int> prompt = WrapAndTokenize(input);
return QueryModel(prompt);
auto result = QueryModel(prompt);
fprintf(stderr, "prompt size: %zu, response size: %zu, total tokens: %zu\n",
prompt.size(), result.tokens_generated - prompt.size(),
result.tokens_generated);
return result;
}
QueryResultAndMetrics GemmaEnv::BatchQueryModelWithMetrics(

View File

@ -62,6 +62,8 @@ class GemmaEnv {
static_cast<size_t>(max_generated_tokens);
}
void PrintProfileResults() { ctx_.profiler.PrintResults(); }
std::vector<int> Tokenize(const std::string& input) const {
std::vector<int> tokens;
HWY_ASSERT(gemma_.Tokenizer().Encode(input, &tokens));

View File

@ -37,7 +37,8 @@ GemmaEnv* s_env = nullptr;
class GemmaBatchBench : public ::testing::Test {
protected:
std::vector<std::string> BatchGemmaReply(
const std::vector<std::string>& inputs) {
const std::vector<std::string>& inputs, AttentionImpl attention_impl) {
s_env->MutableConfig().attention_impl = attention_impl;
s_env->MutableConfig().temperature = 0.0f; // deterministic
s_env->MutableConfig().verbosity = 2;
std::vector<std::string> replies;
@ -128,16 +129,19 @@ std::vector<std::string> GenerateInputs() {
TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
s_env->SetMaxGeneratedTokens(12);
const std::vector<std::string> inputs = GenerateInputs();
// Run multiple times so that auto-tuning is closer to complete.
for (size_t rep = 0; rep < 4; ++rep) {
std::vector<std::string> responses = BatchGemmaReply(inputs);
for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size());
++i) {
fprintf(stderr, "Rep %zu batch answer %zu '%s'\n\n", rep, i,
responses[i].c_str());
const AttentionImpl modes[] = {AttentionImpl::kOld, AttentionImpl::kFlash};
for (const AttentionImpl mode : modes) {
// Run multiple times so that auto-tuning is closer to complete.
fprintf(stderr, "Testing mode %s\n", GetAttentionImplName(mode).c_str());
for (size_t rep = 0; rep < 4; ++rep) {
std::vector<std::string> responses = BatchGemmaReply(inputs, mode);
for (size_t i = 0;
i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size()); ++i) {
fprintf(stderr, "Rep %zu batch answer %zu '%s'\n\n", rep, i,
responses[i].c_str());
}
PROFILER_PRINT_RESULTS();
}
PROFILER_PRINT_RESULTS();
}
}

10
evals/testdata/holiday_story.txt vendored Normal file
View File

@ -0,0 +1,10 @@
Albert and Marcia were on holiday. Their parents had brought them to the beach.
Albert was generally unimpressed with beaches, as he would rather explore a dark forest and see the variety of mosses and fungi that grow in the damp conditions.
On the other hand, Marcia loved to build enormous sand castles.
Albert enjoyed collecting limpet shells to decorate the outer walls of the turrets, which he secretly thought made them look like daleks.
Whilst digging sand for building, Marcia always liked to dig deep, to see if she could get to water coming through the sand from the sea.
When the castle was nearly complete, and Marcia needed more sand, she hit a large piece of rusty metal.
Curious as to what it was, Marcia kept digging to try to expose all of it, but it was very big and hard to get at as it was so deep in the sand.
Excited by the prospect of finding something unusual in the sand, Albert joined in to help dig out the entire object.
Almost an hour later, they had exposed most of a ships anchor.
During the excavation a crowd on onlookers had formed around them, who then proceeded to take selfies in front of the unusual piece of beach litter.

21
evals/testdata/quark_1.txt vendored Normal file
View File

@ -0,0 +1,21 @@
Text from https://en.wikipedia.org/wiki/Quark is licensed under Creative Commons Attribution-ShareAlike 4.0 License; (https://en.wikipedia.org/wiki/Wikipedia:Text_of_the_Creative_Commons_Attribution-ShareAlike_4.0_International_License)
Quark
From Wikipedia, the free encyclopedia
(Redirected from Quarks)
This article is about the elementary particle and its antiparticle. For other uses, see Quark (disambiguation).
A quark (/ˈkwɔːrk, ˈkwɑːrk/ ⓘ) is a type of elementary particle and a fundamental constituent of matter. Quarks combine to form composite particles called hadrons, the most stable of which are protons and neutrons, the components of atomic nuclei.[1] All commonly observable matter is composed of up quarks, down quarks and electrons. Owing to a phenomenon known as color confinement, quarks are never found in isolation; they can be found only within hadrons, which include baryons (such as protons and neutrons) and mesons, or in quarkgluon plasmas.[2][3][nb 1] For this reason, much of what is known about quarks has been drawn from observations of hadrons.
Quarks have various intrinsic properties, including electric charge, mass, color charge, and spin. They are the only elementary particles in the Standard Model of particle physics to experience all four fundamental interactions, also known as fundamental forces (electromagnetism, gravitation, strong interaction, and weak interaction), as well as the only known particles whose electric charges are not integer multiples of the elementary charge.
There are six types, known as flavors, of quarks: up, down, charm, strange, top, and bottom.[4] Up and down quarks have the lowest masses of all quarks. The heavier quarks rapidly change into up and down quarks through a process of particle decay: the transformation from a higher mass state to a lower mass state. Because of this, up and down quarks are generally stable and the most common in the universe, whereas strange, charm, bottom, and top quarks can only be produced in high energy collisions (such as those involving cosmic rays and in particle accelerators). For every quark flavor there is a corresponding type of antiparticle, known as an antiquark, that differs from the quark only in that some of its properties (such as the electric charge) have equal magnitude but opposite sign.
The quark model was independently proposed by physicists Murray Gell-Mann and George Zweig in 1964.[5] Quarks were introduced as parts of an ordering scheme for hadrons, and there was little evidence for their physical existence until deep inelastic scattering experiments at the Stanford Linear Accelerator Center in 1968.[6][7] Accelerator program experiments have provided evidence for all six flavors. The top quark, first observed at Fermilab in 1995, was the last to be discovered.[5]
Classification
See also: Standard Model
A four-by-four table of particles. Columns are three generations of matter (fermions) and one of forces (bosons). In the first three columns, two rows contain quarks and two leptons. The top two rows' columns contain up (u) and down (d) quarks, charm (c) and strange (s) quarks, top (t) and bottom (b) quarks, and photon (γ) and gluon (g), respectively. The bottom two rows' columns contain electron neutrino (ν sub e) and electron (e), muon neutrino (ν sub μ) and muon (μ), and tau neutrino (ν sub τ) and tau (τ), and Z sup 0 and W sup ± weak force. Mass, charge, and spin are listed for each particle.
Six of the particles in the Standard Model are quarks (shown in purple). Each of the first three columns forms a generation of matter.
The Standard Model is the theoretical framework describing all the known elementary particles. This model contains six flavors of quarks (q), named up (u), down (d), strange (s), charm (c), bottom (b), and top (t).[4] Antiparticles of quarks are called antiquarks, and are denoted by a bar over the symbol for the corresponding quark, such as u for an up antiquark. As with antimatter in general, antiquarks have the same mass, mean lifetime, and spin as their respective quarks, but the electric charge and other charges have the opposite sign.[8]

90
evals/testdata/quark_2.txt vendored Normal file
View File

@ -0,0 +1,90 @@
Text from https://en.wikipedia.org/wiki/Quark is licensed under Creative Commons Attribution-ShareAlike 4.0 License; (https://en.wikipedia.org/wiki/Wikipedia:Text_of_the_Creative_Commons_Attribution-ShareAlike_4.0_International_License)
Quark
From Wikipedia, the free encyclopedia
(Redirected from Quarks)
This article is about the elementary particle and its antiparticle. For other uses, see Quark (disambiguation).
Quark
Three colored balls (symbolizing quarks) connected pairwise by springs (symbolizing gluons), all inside a gray circle (symbolizing a proton). The colors of the balls are red, green, and blue, to parallel each quark's color charge. The red and blue balls are labeled "u" (for "up" quark) and the green one is labeled "d" (for "down" quark).
A proton is composed of two up quarks, one down quark, and the gluons that mediate the forces "binding" them together. The color assignment of individual quarks is arbitrary, but all three colors must be present; red, blue and green are used as an analogy to the primary colors that together produce a white color.
Composition elementary particle
Statistics fermionic
Generation 1st, 2nd, 3rd
Interactions strong, weak, electromagnetic, gravitation
Symbol q
Antiparticle antiquark (q)
Theorized
Murray Gell-Mann (1964)
George Zweig (1964)
Discovered SLAC (c.1968)
Types 6 (up, down, strange, charm, bottom, and top)
A quark (/ˈkwɔːrk, ˈkwɑːrk/ ⓘ) is a type of elementary particle and a fundamental constituent of matter. Quarks combine to form composite particles called hadrons, the most stable of which are protons and neutrons, the components of atomic nuclei.[1] All commonly observable matter is composed of up quarks, down quarks and electrons. Owing to a phenomenon known as color confinement, quarks are never found in isolation; they can be found only within hadrons, which include baryons (such as protons and neutrons) and mesons, or in quarkgluon plasmas.[2][3][nb 1] For this reason, much of what is known about quarks has been drawn from observations of hadrons.
Quarks have various intrinsic properties, including electric charge, mass, color charge, and spin. They are the only elementary particles in the Standard Model of particle physics to experience all four fundamental interactions, also known as fundamental forces (electromagnetism, gravitation, strong interaction, and weak interaction), as well as the only known particles whose electric charges are not integer multiples of the elementary charge.
There are six types, known as flavors, of quarks: up, down, charm, strange, top, and bottom.[4] Up and down quarks have the lowest masses of all quarks. The heavier quarks rapidly change into up and down quarks through a process of particle decay: the transformation from a higher mass state to a lower mass state. Because of this, up and down quarks are generally stable and the most common in the universe, whereas strange, charm, bottom, and top quarks can only be produced in high energy collisions (such as those involving cosmic rays and in particle accelerators). For every quark flavor there is a corresponding type of antiparticle, known as an antiquark, that differs from the quark only in that some of its properties (such as the electric charge) have equal magnitude but opposite sign.
The quark model was independently proposed by physicists Murray Gell-Mann and George Zweig in 1964.[5] Quarks were introduced as parts of an ordering scheme for hadrons, and there was little evidence for their physical existence until deep inelastic scattering experiments at the Stanford Linear Accelerator Center in 1968.[6][7] Accelerator program experiments have provided evidence for all six flavors. The top quark, first observed at Fermilab in 1995, was the last to be discovered.[5]
Classification
See also: Standard Model
A four-by-four table of particles. Columns are three generations of matter (fermions) and one of forces (bosons). In the first three columns, two rows contain quarks and two leptons. The top two rows' columns contain up (u) and down (d) quarks, charm (c) and strange (s) quarks, top (t) and bottom (b) quarks, and photon (γ) and gluon (g), respectively. The bottom two rows' columns contain electron neutrino (ν sub e) and electron (e), muon neutrino (ν sub μ) and muon (μ), and tau neutrino (ν sub τ) and tau (τ), and Z sup 0 and W sup ± weak force. Mass, charge, and spin are listed for each particle.
Six of the particles in the Standard Model are quarks (shown in purple). Each of the first three columns forms a generation of matter.
The Standard Model is the theoretical framework describing all the known elementary particles. This model contains six flavors of quarks (q), named up (u), down (d), strange (s), charm (c), bottom (b), and top (t).[4] Antiparticles of quarks are called antiquarks, and are denoted by a bar over the symbol for the corresponding quark, such as u for an up antiquark. As with antimatter in general, antiquarks have the same mass, mean lifetime, and spin as their respective quarks, but the electric charge and other charges have the opposite sign.[8]
Quarks are spin-
1
/
2
particles, which means they are fermions according to the spinstatistics theorem. They are subject to the Pauli exclusion principle, which states that no two identical fermions can simultaneously occupy the same quantum state. This is in contrast to bosons (particles with integer spin), of which any number can be in the same state.[9] Unlike leptons, quarks possess color charge, which causes them to engage in the strong interaction. The resulting attraction between different quarks causes the formation of composite particles known as hadrons (see § Strong interaction and color charge below).
The quarks that determine the quantum numbers of hadrons are called valence quarks; apart from these, any hadron may contain an indefinite number of virtual "sea" quarks, antiquarks, and gluons, which do not influence its quantum numbers.[10] There are two families of hadrons: baryons, with three valence quarks, and mesons, with a valence quark and an antiquark.[11] The most common baryons are the proton and the neutron, the building blocks of the atomic nucleus.[12] A great number of hadrons are known (see list of baryons and list of mesons), most of them differentiated by their quark content and the properties these constituent quarks confer. The existence of "exotic" hadrons with more valence quarks, such as tetraquarks (qqqq) and pentaquarks (qqqqq), was conjectured from the beginnings of the quark model[13] but not discovered until the early 21st century.[14][15][16][17]
Elementary fermions are grouped into three generations, each comprising two leptons and two quarks. The first generation includes up and down quarks, the second strange and charm quarks, and the third bottom and top quarks. All searches for a fourth generation of quarks and other elementary fermions have failed,[18][19] and there is strong indirect evidence that no more than three generations exist.[nb 2][20][21][22] Particles in higher generations generally have greater mass and less stability, causing them to decay into lower-generation particles by means of weak interactions. Only first-generation (up and down) quarks occur commonly in nature. Heavier quarks can only be created in high-energy collisions (such as in those involving cosmic rays), and decay quickly; however, they are thought to have been present during the first fractions of a second after the Big Bang, when the universe was in an extremely hot and dense phase (the quark epoch). Studies of heavier quarks are conducted in artificially created conditions, such as in particle accelerators.[23]
Having electric charge, mass, color charge, and flavor, quarks are the only known elementary particles that engage in all four fundamental interactions of contemporary physics: electromagnetism, gravitation, strong interaction, and weak interaction.[12] Gravitation is too weak to be relevant to individual particle interactions except at extremes of energy (Planck energy) and distance scales (Planck distance). However, since no successful quantum theory of gravity exists, gravitation is not described by the Standard Model.
See the table of properties below for a more complete overview of the six quark flavors' properties.
History
Murray Gell-Mann (2007)
George Zweig (2015)
The quark model was independently proposed by physicists Murray Gell-Mann[24] and George Zweig[25][26] in 1964.[5] The proposal came shortly after Gell-Mann's 1961 formulation of a particle classification system known as the Eightfold Way or, in more technical terms, SU(3) flavor symmetry, streamlining its structure.[27] Physicist Yuval Ne'eman had independently developed a scheme similar to the Eightfold Way in the same year.[28][29] An early attempt at constituent organization was available in the Sakata model.
At the time of the quark theory's inception, the "particle zoo" included a multitude of hadrons, among other particles. Gell-Mann and Zweig posited that they were not elementary particles, but were instead composed of combinations of quarks and antiquarks. Their model involved three flavors of quarks, up, down, and strange, to which they ascribed properties such as spin and electric charge.[24][25][26] The initial reaction of the physics community to the proposal was mixed. There was particular contention about whether the quark was a physical entity or a mere abstraction used to explain concepts that were not fully understood at the time.[30]
In less than a year, extensions to the Gell-MannZweig model were proposed. Sheldon Glashow and James Bjorken predicted the existence of a fourth flavor of quark, which they called charm. The addition was proposed because it allowed for a better description of the weak interaction (the mechanism that allows quarks to decay), equalized the number of known quarks with the number of known leptons, and implied a mass formula that correctly reproduced the masses of the known mesons.[31]
Deep inelastic scattering experiments conducted in 1968 at the Stanford Linear Accelerator Center (SLAC) and published on October 20, 1969, showed that the proton contained much smaller, point-like objects and was therefore not an elementary particle.[6][7][32] Physicists were reluctant to firmly identify these objects with quarks at the time, instead calling them "partons" a term coined by Richard Feynman.[33][34][35] The objects that were observed at SLAC would later be identified as up and down quarks as the other flavors were discovered.[36] Nevertheless, "parton" remains in use as a collective term for the constituents of hadrons (quarks, antiquarks, and gluons). Richard Taylor, Henry Kendall and Jerome Friedman received the 1990 Nobel Prize in physics for their work at SLAC.
Photo of bubble chamber tracks next to diagram of same tracks. A neutrino (unseen in photo) enters from below and collides with a proton, producing a negatively charged muon, three positively charged pions, and one negatively charged pion, as well as a neutral lambda baryon (unseen in photograph). The lambda baryon then decays into a proton and a negative pion, producing a "V" pattern.
Photograph of the event that led to the discovery of the Σ++
c baryon, at the Brookhaven National Laboratory in 1974
The strange quark's existence was indirectly validated by SLAC's scattering experiments: not only was it a necessary component of Gell-Mann and Zweig's three-quark model, but it provided an explanation for the kaon (K) and pion (π) hadrons discovered in cosmic rays in 1947.[37]
In a 1970 paper, Glashow, John Iliopoulos and Luciano Maiani presented the GIM mechanism (named from their initials) to explain the experimental non-observation of flavor-changing neutral currents. This theoretical model required the existence of the as-yet undiscovered charm quark.[38][39] The number of supposed quark flavors grew to the current six in 1973, when Makoto Kobayashi and Toshihide Maskawa noted that the experimental observation of CP violation[nb 3][40] could be explained if there were another pair of quarks.
Charm quarks were produced almost simultaneously by two teams in November 1974 (see November Revolution) one at SLAC under Burton Richter, and one at Brookhaven National Laboratory under Samuel Ting. The charm quarks were observed bound with charm antiquarks in mesons. The two parties had assigned the discovered meson two different symbols, J and ψ; thus, it became formally known as the J/ψ meson. The discovery finally convinced the physics community of the quark model's validity.[35]
In the following years a number of suggestions appeared for extending the quark model to six quarks. Of these, the 1975 paper by Haim Harari[41] was the first to coin the terms top and bottom for the additional quarks.[42]
In 1977, the bottom quark was observed by a team at Fermilab led by Leon Lederman.[43][44] This was a strong indicator of the top quark's existence: without the top quark, the bottom quark would have been without a partner. It was not until 1995 that the top quark was finally observed, also by the CDF[45] and DØ[46] teams at Fermilab.[5] It had a mass much larger than expected,[47] almost as large as that of a gold atom.[48]
Etymology
For some time, Gell-Mann was undecided on an actual spelling for the term he intended to coin, until he found the word quark in James Joyce's 1939 book Finnegans Wake:[49]
Three quarks for Muster Mark!
Sure he hasn't got much of a bark
And sure any he has it's all beside the mark.
The word quark is an old English word meaning to croak[50] and the above-quoted lines are about a bird choir mocking king Mark of Cornwall in the legend of Tristan and Iseult.[51] Especially in the German-speaking parts of the world there is a widespread legend, however, that Joyce had taken it from the word Quark,[52] a German word of Slavic origin which denotes a curd cheese,[53] but is also a colloquial term for "trivial nonsense".[54] In the legend it is said that he had heard it on a journey to Germany at a farmers' market in Freiburg.[55][56] Some authors, however, defend a possible German origin of Joyce's word quark.[57] Gell-Mann went into further detail regarding the name of the quark in his 1994 book The Quark and the Jaguar:[58]
In 1963, when I assigned the name "quark" to the fundamental constituents of the nucleon, I had the sound first, without the spelling, which could have been "kwork". Then, in one of my occasional perusals of Finnegans Wake, by James Joyce, I came across the word "quark" in the phrase "Three quarks for Muster Mark". Since "quark" (meaning, for one thing, the cry of the gull) was clearly intended to rhyme with "Mark", as well as "bark" and other such words, I had to find an excuse to pronounce it as "kwork". But the book represents the dream of a publican named Humphrey Chimpden Earwicker. Words in the text are typically drawn from several sources at once, like the "portmanteau" words in Through the Looking-Glass. From time to time, phrases occur in the book that are partially determined by calls for drinks at the bar. I argued, therefore, that perhaps one of the multiple sources of the cry "Three quarks for Muster Mark" might be "Three quarts for Mister Mark", in which case the pronunciation "kwork" would not be totally unjustified. In any case, the number three fitted perfectly the way quarks occur in nature.
Zweig preferred the name ace for the particle he had theorized, but Gell-Mann's terminology came to prominence once the quark model had been commonly accepted.[59]
The quark flavors were given their names for several reasons. The up and down quarks are named after the up and down components of isospin, which they carry.[60] Strange quarks were given their name because they were discovered to be components of the strange particles discovered in cosmic rays years before the quark model was proposed; these particles were deemed "strange" because they had unusually long lifetimes.[61] Glashow, who co-proposed the charm quark with Bjorken, is quoted as saying, "We called our construct the 'charmed quark', for we were fascinated and pleased by the symmetry it brought to the subnuclear world."[62] The names "top" and "bottom", coined by Harari, were chosen because they are "logical partners for up and down quarks".[41][42][61] Alternative names for top and bottom quarks are "truth" and "beauty" respectively,[nb 4] but these names have somewhat fallen out of use.[66] While "truth" never did catch on, accelerator complexes devoted to massive production of bottom quarks are sometimes called "beauty factories".[67]

3984
evals/testdata/special_relativity.txt vendored Normal file

File diff suppressed because it is too large Load Diff

851
evals/testdata/standard_model.txt vendored Normal file
View File

@ -0,0 +1,851 @@
Text from https://en.wikipedia.org/wiki/Standard_Model is licensed under Creative Commons Attribution-ShareAlike 4.0 License; (https://en.wikipedia.org/wiki/Wikipedia:Text_of_the_Creative_Commons_Attribution-ShareAlike_4.0_International_License)
Standard Model
Article
Talk
Read
Edit
View history
Tools
Appearance hide
Text
Small
Standard
Large
Width
Standard
Wide
Color (beta)
Automatic
Light
Dark
From Wikipedia, the free encyclopedia
This article is about a non-mathematical general overview of the Standard Model of particle physics. For a mathematical description, see Mathematical formulation of the Standard Model. For other uses, see Standard model (disambiguation).
Standard Model of particle physics
Elementary particles of the Standard Model
Background
Constituents
Limitations
Scientists
vte
The Standard Model of particle physics is the theory describing three of the four known fundamental forces (electromagnetic, weak and strong interactions excluding gravity) in the universe and classifying all known elementary particles. It was developed in stages throughout the latter half of the 20th century, through the work of many scientists worldwide,[1] with the current formulation being finalized in the mid-1970s upon experimental confirmation of the existence of quarks. Since then, proof of the top quark (1995), the tau neutrino (2000), and the Higgs boson (2012) have added further credence to the Standard Model. In addition, the Standard Model has predicted with great accuracy the various properties of weak neutral currents and the W and Z bosons.
Although the Standard Model is believed to be theoretically self-consistent[note 1] and has demonstrated some success in providing experimental predictions, it leaves some physical phenomena unexplained and so falls short of being a complete theory of fundamental interactions.[3] For example, it does not fully explain why there is more matter than anti-matter, incorporate the full theory of gravitation[4] as described by general relativity, or account for the universe's accelerating expansion as possibly described by dark energy. The model does not contain any viable dark matter particle that possesses all of the required properties deduced from observational cosmology. It also does not incorporate neutrino oscillations and their non-zero masses.
The development of the Standard Model was driven by theoretical and experimental particle physicists alike. The Standard Model is a paradigm of a quantum field theory for theorists, exhibiting a wide range of phenomena, including spontaneous symmetry breaking, anomalies, and non-perturbative behavior. It is used as a basis for building more exotic models that incorporate hypothetical particles, extra dimensions, and elaborate symmetries (such as supersymmetry) to explain experimental results at variance with the Standard Model, such as the existence of dark matter and neutrino oscillations.
Historical background
See also: History of quantum field theory, History of subatomic physics, Julian Schwinger, and John Clive Ward
In 1928, Paul Dirac introduced the Dirac equation, which implied the existence of antimatter.[5] In 1954, Yang Chen-Ning and Robert Mills extended the concept of gauge theory for abelian groups, e.g. quantum electrodynamics, to nonabelian groups to provide an explanation for strong interactions.[6] In 1957, Chien-Shiung Wu demonstrated parity was not conserved in the weak interaction.[7] In 1961, Sheldon Glashow combined the electromagnetic and weak interactions.[8] In 1964, Murray Gell-Mann and George Zweig introduced quarks and that same year Oscar W. Greenberg implicitly introduced color charge of quarks.[9] In 1967 Steven Weinberg[10] and Abdus Salam[11] incorporated the Higgs mechanism[12][13][14] into Glashow's electroweak interaction, giving it its modern form.
In 1970, Sheldon Glashow, John Iliopoulos, and Luciano Maiani introduced the GIM mechanism, predicting the charm quark.[15] In 1973 Gross and Wilczek and Politzer independently discovered that non-Abelian gauge theories, like the color theory of the strong force, have asymptotic freedom.[15] In 1976, Martin Perl discovered the tau lepton at the SLAC.[16][17] In 1977, a team led by Leon Lederman at Fermilab discovered the bottom quark.[18]
The Higgs mechanism is believed to give rise to the masses of all the elementary particles in the Standard Model. This includes the masses of the W and Z bosons, and the masses of the fermions, i.e. the quarks and leptons.
After the neutral weak currents caused by Z boson exchange were discovered at CERN in 1973,[19][20][21][22] the electroweak theory became widely accepted and Glashow, Salam, and Weinberg shared the 1979 Nobel Prize in Physics for discovering it. The W± and Z0 bosons were discovered experimentally in 1983; and the ratio of their masses was found to be as the Standard Model predicted.[23]
The theory of the strong interaction (i.e. quantum chromodynamics, QCD), to which many contributed, acquired its modern form in 197374 when asymptotic freedom was proposed[24][25] (a development that made QCD the main focus of theoretical research)[26] and experiments confirmed that the hadrons were composed of fractionally charged quarks.[27][28]
The term "Standard Model" was introduced by Abraham Pais and Sam Treiman in 1975,[29] with reference to the electroweak theory with four quarks.[30] Steven Weinberg has since claimed priority, explaining that he chose the term Standard Model out of a sense of modesty[a][31][32][better source needed] and used it in 1973 during a talk in Aix-en-Provence in France.[33]
Particle content
The Standard Model includes members of several classes of elementary particles, which in turn can be distinguished by other characteristics, such as color charge.
All particles can be summarized as follows:
vte
Elementary particles
Elementary fermions
Half-integer spin
Obey the FermiDirac statistics
Elementary bosons
Integer spin
Obey the BoseEinstein statistics
Quarks and antiquarks
Spin =
1
/
2
Fractional electric charge
Have color charge
Participate in both strong interactions
and in electroweak interactions
Leptons and antileptons
Spin =
1
/
2
Integer electric charge
No color charge
Participate in Electroweak interactions
Gauge bosons
Spin = 1
Force carriers
Scalar bosons
Spin = 0
Three generations
Up (u),Down (d)
Charm (c),Strange (s)
Top (t),Bottom (b)
Three generations
Electron (e
), [†] Electron neutrino (ν
e)
Muon (μ−
),Muon neutrino (ν
μ)
Tau (τ−
),Tau neutrino (ν
τ)
Three kinds
Photon
(γ; electromagnetic interaction)
W and Z bosons
(W+
, W
, Z0
; weak interaction)
Eight types of gluons
(g; strong interaction)
One kind
Higgs boson (H0
)
Notes:
[†] An anti-electron (e+
) is conventionally called a "positron".
Fermions
The Standard Model includes 12 elementary particles of spin
1
/
2
, known as fermions.[34] Fermions respect the Pauli exclusion principle, meaning that two identical fermions cannot simultaneously occupy the same quantum state in the same atom.[35] Each fermion has a corresponding antiparticle, which are particles that have corresponding properties with the exception of opposite charges.[36] Fermions are classified based on how they interact, which is determined by the charges they carry, into two groups: quarks and leptons. Within each group, pairs of particles that exhibit similar physical behaviors are then grouped into generations (see the table). Each member of a generation has a greater mass than the corresponding particle of generations prior. Thus, there are three generations of quarks and leptons.[37] As first-generation particles do not decay, they comprise all of ordinary (baryonic) matter.[38] Specifically, all atoms consist of electrons orbiting around the atomic nucleus, ultimately constituted of up and down quarks. On the other hand, second- and third-generation charged particles decay with very short half-lives and can only be observed in high-energy environments. Neutrinos of all generations also do not decay, and pervade the universe, but rarely interact with baryonic matter.
There are six quarks: up, down, charm, strange, top, and bottom.[34][37] Quarks carry color charge, and hence interact via the strong interaction. The color confinement phenomenon results in quarks being strongly bound together such that they form color-neutral composite particles called hadrons; quarks cannot individually exist and must always bind with other quarks. Hadrons can contain either a quark-antiquark pair (mesons) or three quarks (baryons).[39] The lightest baryons are the nucleons: the proton and neutron. Quarks also carry electric charge and weak isospin, and thus interact with other fermions through electromagnetism and weak interaction. The six leptons consist of the electron, electron neutrino, muon, muon neutrino, tau, and tau neutrino. The leptons do not carry color charge, and do not respond to strong interaction. The charged leptons carry an electric charge of 1 e, while the three neutrinos carry zero electric charge. Thus, the neutrinos' motions are influenced by only the weak interaction and gravity, making them difficult to observe.
Gauge bosons
Interactions in the Standard Model. All Feynman diagrams in the model are built from combinations of these vertices. q is any quark, g is a gluon, X represents any electrically charged particle, γ is a photon, f is any fermion, m is any particle with mass (with the possible exception of some neutrinos); mB is any massive boson. In diagrams with multiple particle labels separated by '/', one particle label is chosen. However, in those diagrams with particle labels separated by '|', the labels must be chosen in the same left-to-right order. For example, in the four boson electroweak case the valid diagrams are WWWW, WWZZ, WWγγ, WWZγ. The conjugate of each listed vertex (reversing the direction of arrows) is also allowed.[40]
The Standard Model includes 4 kinds of gauge bosons of spin 1,[34] with bosons being quantum particles containing an integer spin. The gauge bosons are defined as force carriers, as they are responsible for mediating the fundamental interactions. The Standard Model explains the four fundamental forces as arising from the interactions, with fermions exchanging virtual force carrier particles, thus mediating the forces. At a macroscopic scale, this manifests as a force.[41] As a result, they do not follow the Pauli exclusion principle that constrains fermions; bosons do not have a theoretical limit on their spatial density. The types of gauge bosons are described below.
Electromagnetism
Photons (γ) mediate the electromagnetic force, responsible for interactions between electrically charged particles. At present, the photon is the only known massless particle, and its interactions with other matter are described by the theory of quantum electrodynamics (QED).
Strong interaction
Gluons (g) mediate the strong interactions, which binds quarks to each other by influencing the color charge, with the interactions being described in the theory of quantum chromodynamics (QCD). For theoretical convenience, gluons are presumed to have no mass. There are eight distinct gluons, by color charge, with each being denoted through a color-and-anticolor combination (e.g. redantigreen).[note 2] As gluons have an effective color charge, they also interact amongst themselves.
Weak interaction
The W+
, W
, and Z0
gauge bosons mediate weak interactions between fermions; the W±
are responsible for radioactive decay, and Z0
deflect neutrinos traveling through solid matter. They have large intrinsic mass, with the Z0
having a little more mass than the W±
, and approximately the same mass as an entire atom of zirconium. Strangely, weak interactions involving the W±
only ever act on left-handed particles and right-handed antiparticles respectively, whereas interactions with the Z0
involve both left- and right-handed particles and antiparticles.[note 3][note 4]
Gravitation
In the Standard Model gravitation is currently only approximately explained, and then only for relatively low-strength gravitational fields, as the hypothetical mediating particle graviton has been proposed and described, but never observed.[43] This is due to the incompatibility of quantum mechanics and Einstein's theory of general relativity, regarded as being the best explanation for gravity. In general relativity, gravity is explained as being the geometric curving of spacetime.[44]
The Feynman diagram calculations, which are a graphical representation of the perturbation theory approximation, invoke "force mediating particles", and when applied to analyze high-energy scattering experiments are in reasonable agreement with the data. However, perturbation theory (and with it the concept of a "force-mediating particle") fails in other situations. These include low-energy quantum chromodynamics, bound states, and solitons. The interactions between all the particles described by the Standard Model are summarized by the diagrams on the right of this section.
Higgs boson
Main article: Higgs boson
The Higgs particle is a massive scalar elementary particle theorized by Peter Higgs (and others) in 1964, when he showed that Goldstone's 1962 theorem (generic continuous symmetry, which is spontaneously broken) provides a third polarization of a massive vector field. Hence, Goldstone's original scalar doublet, the massive spin-zero particle, was proposed as the Higgs boson, and is a key building block in the Standard Model.[45] It has no intrinsic spin, and for that reason is classified as a boson with spin-0.[34]
The Higgs boson plays a unique role in the Standard Model, by explaining why the other elementary particles, except the photon and gluon, are massive. In particular, the Higgs boson explains why the photon has no mass, while the W and Z bosons are very heavy. Elementary-particle masses and the differences between electromagnetism (mediated by the photon) and the weak force (mediated by the W and Z bosons) are critical to many aspects of the structure of microscopic (and hence macroscopic) matter. In electroweak theory, the Higgs boson generates the masses of the leptons (electron, muon, and tau) and quarks. As the Higgs boson is massive, it must interact with itself.
Because the Higgs boson is a very massive particle and also decays almost immediately when created, only a very high-energy particle accelerator can observe and record it. Experiments to confirm and determine the nature of the Higgs boson using the Large Hadron Collider (LHC) at CERN began in early 2010 and were performed at Fermilab's Tevatron until its closure in late 2011. Mathematical consistency of the Standard Model requires that any mechanism capable of generating the masses of elementary particles must become visible[clarification needed] at energies above 1.4 TeV;[46] therefore, the LHC (designed to collide two 7 TeV proton beams) was built to answer the question of whether the Higgs boson actually exists.[47]
On 4 July 2012, two of the experiments at the LHC (ATLAS and CMS) both reported independently that they had found a new particle with a mass of about 125 GeV/c2 (about 133 proton masses, on the order of 1025 kg), which is "consistent with the Higgs boson".[48][49] On 13 March 2013, it was confirmed to be the searched-for Higgs boson.[50][51]
Theoretical aspects
Main article: Mathematical formulation of the Standard Model
Construction of the Standard Model Lagrangian
Parameters of the Standard Model
Technically, quantum field theory provides the mathematical framework for the Standard Model, in which a Lagrangian controls the dynamics and kinematics of the theory. Each kind of particle is described in terms of a dynamical field that pervades space-time.[52] The construction of the Standard Model proceeds following the modern method of constructing most field theories: by first postulating a set of symmetries of the system, and then by writing down the most general renormalizable Lagrangian from its particle (field) content that observes these symmetries.
The global Poincaré symmetry is postulated for all relativistic quantum field theories. It consists of the familiar translational symmetry, rotational symmetry and the inertial reference frame invariance central to the theory of special relativity. The local SU(3) × SU(2) × U(1) gauge symmetry is an internal symmetry that essentially defines the Standard Model. Roughly, the three factors of the gauge symmetry give rise to the three fundamental interactions. The fields fall into different representations of the various symmetry groups of the Standard Model (see table). Upon writing the most general Lagrangian, one finds that the dynamics depends on 19 parameters, whose numerical values are established by experiment. The parameters are summarized in the table (made visible by clicking "show") above.
Quantum chromodynamics sector
Main article: Quantum chromodynamics
The quantum chromodynamics (QCD) sector defines the interactions between quarks and gluons, which is a YangMills gauge theory with SU(3) symmetry, generated by
T
a
=
λ
a
/
2
{\displaystyle T^{a}=\lambda ^{a}/2}. Since leptons do not interact with gluons, they are not affected by this sector. The Dirac Lagrangian of the quarks coupled to the gluon fields is given by
L
QCD
=
ψ
¯
i
γ
μ
D
μ
ψ
1
4
G
μ
ν
a
G
a
μ
ν
,
{\displaystyle {\mathcal {L}}_{\text{QCD}}={\overline {\psi }}i\gamma ^{\mu }D_{\mu }\psi -{\frac {1}{4}}G_{\mu \nu }^{a}G_{a}^{\mu \nu },}where
ψ
{\displaystyle \psi } is a three component column vector of Dirac spinors, each element of which refers to a quark field with a specific color charge (i.e. red, blue, and green) and summation over flavor (i.e. up, down, strange, etc.) is implied.
The gauge covariant derivative of QCD is defined by
D
μ
μ
i
g
s
1
2
λ
a
G
μ
a
{\displaystyle D_{\mu }\equiv \partial _{\mu }-ig_{\text{s}}{\frac {1}{2}}\lambda ^{a}G_{\mu }^{a}}, where
γμ are the Dirac matrices,
Ga
μ is the 8-component (
a
=
1
,
2
,
,
8
{\displaystyle a=1,2,\dots ,8}) SU(3) gauge field,
λa
are the 3 × 3 Gell-Mann matrices, generators of the SU(3) color group,
Ga
μν represents the gluon field strength tensor, and
gs is the strong coupling constant.
The QCD Lagrangian is invariant under local SU(3) gauge transformations; i.e., transformations of the form
ψ
ψ
=
U
ψ
{\displaystyle \psi \rightarrow \psi '=U\psi }, where
U
=
e
i
g
s
λ
a
ϕ
a
(
x
)
{\displaystyle U=e^{-ig_{\text{s}}\lambda ^{a}\phi ^{a}(x)}} is 3 × 3 unitary matrix with determinant 1, making it a member of the group SU(3), and
ϕ
a
(
x
)
{\displaystyle \phi ^{a}(x)} is an arbitrary function of spacetime.
Electroweak sector
Main article: Electroweak interaction
The electroweak sector is a YangMills gauge theory with the symmetry group U(1) × SU(2)L,
L
EW
=
Q
¯
L
j
i
γ
μ
D
μ
Q
L
j
+
u
¯
R
j
i
γ
μ
D
μ
u
R
j
+
d
¯
R
j
i
γ
μ
D
μ
d
R
j
+
¯
L
j
i
γ
μ
D
μ
L
j
+
e
¯
R
j
i
γ
μ
D
μ
e
R
j
1
4
W
a
μ
ν
W
μ
ν
a
1
4
B
μ
ν
B
μ
ν
,
{\displaystyle {\mathcal {L}}_{\text{EW}}={\overline {Q}}_{{\text{L}}j}i\gamma ^{\mu }D_{\mu }Q_{{\text{L}}j}+{\overline {u}}_{{\text{R}}j}i\gamma ^{\mu }D_{\mu }u_{{\text{R}}j}+{\overline {d}}_{{\text{R}}j}i\gamma ^{\mu }D_{\mu }d_{{\text{R}}j}+{\overline {\ell }}_{{\text{L}}j}i\gamma ^{\mu }D_{\mu }\ell _{{\text{L}}j}+{\overline {e}}_{{\text{R}}j}i\gamma ^{\mu }D_{\mu }e_{{\text{R}}j}-{\tfrac {1}{4}}W_{a}^{\mu \nu }W_{\mu \nu }^{a}-{\tfrac {1}{4}}B^{\mu \nu }B_{\mu \nu },}where the subscript
j
{\displaystyle j} sums over the three generations of fermions;
Q
L
,
u
R
{\displaystyle Q_{\text{L}},u_{\text{R}}}, and
d
R
{\displaystyle d_{\text{R}}} are the left-handed doublet, right-handed singlet up type, and right handed singlet down type quark fields; and
L
{\displaystyle \ell _{\text{L}}} and
e
R
{\displaystyle e_{\text{R}}} are the left-handed doublet and right-handed singlet lepton fields.
The electroweak gauge covariant derivative is defined as
D
μ
μ
i
g
1
2
Y
W
B
μ
i
g
1
2
τ
L
W
μ
{\displaystyle D_{\mu }\equiv \partial _{\mu }-ig'{\tfrac {1}{2}}Y_{\text{W}}B_{\mu }-ig{\tfrac {1}{2}}{\vec {\tau }}_{\text{L}}{\vec {W}}_{\mu }}, where
Bμ is the U(1) gauge field,
YW is the weak hypercharge the generator of the U(1) group,
W→μ is the 3-component SU(2) gauge field,
τ
L are the Pauli matrices infinitesimal generators of the SU(2) group with subscript L to indicate that they only act on left-chiral fermions,
g' and g are the U(1) and SU(2) coupling constants respectively,
W
a
μ
ν
{\displaystyle W^{a\mu \nu }} (
a
=
1
,
2
,
3
{\displaystyle a=1,2,3}) and
B
μ
ν
{\displaystyle B^{\mu \nu }} are the field strength tensors for the weak isospin and weak hypercharge fields.
Notice that the addition of fermion mass terms into the electroweak Lagrangian is forbidden, since terms of the form
m
ψ
¯
ψ
{\displaystyle m{\overline {\psi }}\psi } do not respect U(1) × SU(2)L gauge invariance. Neither is it possible to add explicit mass terms for the U(1) and SU(2) gauge fields. The Higgs mechanism is responsible for the generation of the gauge boson masses, and the fermion masses result from Yukawa-type interactions with the Higgs field.
Higgs sector
Main article: Higgs mechanism
In the Standard Model, the Higgs field is an SU(2)L doublet of complex scalar fields with four degrees of freedom:
φ
=
(
φ
+
φ
0
)
=
1
2
(
φ
1
+
i
φ
2
φ
3
+
i
φ
4
)
,
{\displaystyle \varphi ={\begin{pmatrix}\varphi ^{+}\\\varphi ^{0}\end{pmatrix}}={\frac {1}{\sqrt {2}}}{\begin{pmatrix}\varphi _{1}+i\varphi _{2}\\\varphi _{3}+i\varphi _{4}\end{pmatrix}},}where the superscripts + and 0 indicate the electric charge
Q
{\displaystyle Q} of the components. The weak hypercharge
Y
W
{\displaystyle Y_{\text{W}}} of both components is 1. Before symmetry breaking, the Higgs Lagrangian is
L
H
=
(
D
μ
φ
)
(
D
μ
φ
)
V
(
φ
)
,
{\displaystyle {\mathcal {L}}_{\text{H}}=\left(D_{\mu }\varphi \right)^{\dagger }\left(D^{\mu }\varphi \right)-V(\varphi ),}where
D
μ
{\displaystyle D_{\mu }} is the electroweak gauge covariant derivative defined above and
V
(
φ
)
{\displaystyle V(\varphi )} is the potential of the Higgs field. The square of the covariant derivative leads to three and four point interactions between the electroweak gauge fields
W
μ
a
{\displaystyle W_{\mu }^{a}} and
B
μ
{\displaystyle B_{\mu }} and the scalar field
φ
{\displaystyle \varphi }. The scalar potential is given by
V
(
φ
)
=
μ
2
φ
φ
+
λ
(
φ
φ
)
2
,
{\displaystyle V(\varphi )=-\mu ^{2}\varphi ^{\dagger }\varphi +\lambda \left(\varphi ^{\dagger }\varphi \right)^{2},}where
μ
2
>
0
{\displaystyle \mu ^{2}>0}, so that
φ
{\displaystyle \varphi } acquires a non-zero Vacuum expectation value, which generates masses for the Electroweak gauge fields (the Higgs mechanism), and
λ
>
0
{\displaystyle \lambda >0}, so that the potential is bounded from below. The quartic term describes self-interactions of the scalar field
φ
{\displaystyle \varphi }.
The minimum of the potential is degenerate with an infinite number of equivalent ground state solutions, which occurs when
φ
φ
=
μ
2
2
λ
{\displaystyle \varphi ^{\dagger }\varphi ={\tfrac {\mu ^{2}}{2\lambda }}}. It is possible to perform a gauge transformation on
φ
{\displaystyle \varphi } such that the ground state is transformed to a basis where
φ
1
=
φ
2
=
φ
4
=
0
{\displaystyle \varphi _{1}=\varphi _{2}=\varphi _{4}=0} and
φ
3
=
μ
λ
v
{\displaystyle \varphi _{3}={\tfrac {\mu }{\sqrt {\lambda }}}\equiv v}. This breaks the symmetry of the ground state. The expectation value of
φ
{\displaystyle \varphi } now becomes
φ
=
1
2
(
0
v
)
,
{\displaystyle \langle \varphi \rangle ={\frac {1}{\sqrt {2}}}{\begin{pmatrix}0\\v\end{pmatrix}},}where
v
{\displaystyle v} has units of mass and sets the scale of electroweak physics. This is the only dimensional parameter of the Standard Model and has a measured value of ~246 GeV/c2.
After symmetry breaking, the masses of the W and Z are given by
m
W
=
1
2
g
v
{\displaystyle m_{\text{W}}={\frac {1}{2}}gv} and
m
Z
=
1
2
g
2
+
g
2
v
{\displaystyle m_{\text{Z}}={\frac {1}{2}}{\sqrt {g^{2}+g'^{2}}}v}, which can be viewed as predictions of the theory. The photon remains massless. The mass of the Higgs boson is
m
H
=
2
μ
2
=
2
λ
v
{\displaystyle m_{\text{H}}={\sqrt {2\mu ^{2}}}={\sqrt {2\lambda }}v}. Since
μ
{\displaystyle \mu } and
λ
{\displaystyle \lambda } are free parameters, the Higgs's mass could not be predicted beforehand and had to be determined experimentally.
Yukawa sector
The Yukawa interaction terms are:
L
Yukawa
=
(
Y
u
)
m
n
(
Q
¯
L
)
m
φ
~
(
u
R
)
n
+
(
Y
d
)
m
n
(
Q
¯
L
)
m
φ
(
d
R
)
n
+
(
Y
e
)
m
n
(
¯
L
)
m
φ
(
e
R
)
n
+
h
.
c
.
{\displaystyle {\mathcal {L}}_{\text{Yukawa}}=(Y_{\text{u}})_{mn}({\bar {Q}}_{\text{L}})_{m}{\tilde {\varphi }}(u_{\text{R}})_{n}+(Y_{\text{d}})_{mn}({\bar {Q}}_{\text{L}})_{m}\varphi (d_{\text{R}})_{n}+(Y_{\text{e}})_{mn}({\bar {\ell }}_{\text{L}})_{m}{\varphi }(e_{\text{R}})_{n}+\mathrm {h.c.} }where
Y
u
{\displaystyle Y_{\text{u}}},
Y
d
{\displaystyle Y_{\text{d}}}, and
Y
e
{\displaystyle Y_{\text{e}}} are 3×3 matrices of Yukawa couplings, with the mn term giving the coupling of the generations m and n, and h.c. means Hermitian conjugate of preceding terms. The fields
Q
L
{\displaystyle Q_{\text{L}}} and
L
{\displaystyle \ell _{\text{L}}} are left-handed quark and lepton doublets. Likewise,
u
R
,
d
R
{\displaystyle u_{\text{R}},d_{\text{R}}} and
e
R
{\displaystyle e_{\text{R}}} are right-handed up-type quark, down-type quark, and lepton singlets. Finally
φ
{\displaystyle \varphi } is the Higgs doublet and
φ
~
=
i
τ
2
φ
{\displaystyle {\tilde {\varphi }}=i\tau _{2}\varphi ^{*}} is its charge conjugate state.
The Yukawa terms are invariant under the SU(2)L × U(1)Y gauge symmetry of the Standard Model and generate masses for all fermions after spontaneous symmetry breaking.
Fundamental interactions
Main article: Fundamental interaction
The Standard Model describes three of the four fundamental interactions in nature; only gravity remains unexplained. In the Standard Model, such an interaction is described as an exchange of bosons between the objects affected, such as a photon for the electromagnetic force and a gluon for the strong interaction. Those particles are called force carriers or messenger particles.[53]
The four fundamental interactions of nature[54]
Property/Interaction Gravitation Electroweak Strong
Weak Electromagnetic Fundamental Residual
Mediating particles Not yet observed
(Graviton hypothesised) W+, W and Z0 γ (photon) Gluons π, ρ and ω mesons
Affected particles All particles W+, W: Left-handed fermions; Z0: All fermions Electrically charged Quarks, gluons Hadrons
Acts on Stressenergy tensor Flavor Electric charge Color charge
Bound states formed Planets, stars, galaxies, galaxy groups —N/a Atoms, molecules Hadrons Atomic nuclei
Strength at the scale of quarks
(relative to electromagnetism) 1041 (predicted) 104 1 60 Not applicable
to quarks
Strength at the scale of
protons/neutrons
(relative to electromagnetism) 1036 (predicted) 107 1 Not applicable
to hadrons 20
icon
This section does not cite any sources. Please help improve this section by adding citations to reliable sources. Unsourced material may be challenged and removed. (June 2021) (Learn how and when to remove this message)
Gravity
See also: Quantum gravity and Gravity
Fundamental Interactions of the Standard Model including the hypothetical graviton
Despite being perhaps the most familiar fundamental interaction, gravity is not described by the Standard Model, due to contradictions that arise when combining general relativity, the modern theory of gravity, and quantum mechanics.[55][56] However, gravity is so weak at microscopic scales, that it is essentially unmeasurable. The graviton is postulated to be the mediating particle, but has not yet been proved to exist.[57]
Electromagnetism
See also: Electromagnetism and Quantum electrodynamics
Electromagnetism is the only long-range force in the Standard Model. It is mediated by photons and couples to electric charge.[58] Electromagnetism is responsible for a wide range of phenomena including atomic electron shell structure, chemical bonds, electric circuits and electronics. Electromagnetic interactions in the Standard Model are described by quantum electrodynamics.
Weak interaction
See also: Weak interaction and Electroweak interaction
The weak interaction is responsible for various forms of particle decay, such as beta decay. It is weak and short-range, due to the fact that the weak mediating particles, W and Z bosons, have mass. W bosons have electric charge and mediate interactions that change the particle type (referred to as flavor) and charge. Interactions mediated by W bosons are charged current interactions. Z bosons are neutral and mediate neutral current interactions, which do not change particle flavor. Thus Z bosons are similar to the photon, aside from them being massive and interacting with the neutrino. The weak interaction is also the only interaction to violate parity and CP. Parity violation is maximal for charged current interactions, since the W boson interacts exclusively with left-handed fermions and right-handed antifermions.
In the Standard Model, the weak force is understood in terms of the electroweak theory, which states that the weak and electromagnetic interactions become united into a single electroweak interaction at high energies.
Strong interaction
See also: Strong interaction, Nuclear force, and Quantum chromodynamics
The strong interaction is responsible for hadronic and nuclear binding. It is mediated by gluons, which couple to color charge. Since gluons themselves have color charge, the strong force exhibits confinement and asymptotic freedom. Confinement means that only color-neutral particles can exist in isolation, therefore quarks can only exist in hadrons and never in isolation, at low energies. Asymptotic freedom means that the strong force becomes weaker, as the energy scale increases. The strong force overpowers the electrostatic repulsion of protons and quarks in nuclei and hadrons respectively, at their respective scales.
While quarks are bound in hadrons by the fundamental strong interaction, which is mediated by gluons, nucleons are bound by an emergent phenomenon termed the residual strong force or nuclear force. This interaction is mediated by mesons, such as the pion. The color charges inside the nucleon cancel out, meaning most of the gluon and quark fields cancel out outside of the nucleon. However, some residue is "leaked", which appears as the exchange of virtual mesons, which result in an effective attractive force between nucleons. The (fundamental) strong interaction is described by quantum chromodynamics, which is a component of the Standard Model.
Tests and predictions
The Standard Model predicted the existence of the W and Z bosons, gluon, top quark and charm quark, and predicted many of their properties before these particles were observed. The predictions were experimentally confirmed with good precision.[59]
The Standard Model also predicted the existence of the Higgs boson, which was found in 2012 at the Large Hadron Collider, the final fundamental particle predicted by the Standard Model to be experimentally confirmed.[60]
Challenges
See also: Physics beyond the Standard Model
Unsolved problem in physics
What gives rise to the Standard Model of particle physics?
Why do particle masses and coupling constants have the values that we measure?
Why are there three generations of particles?
Why is there more matter than antimatter in the universe?
Where does dark matter fit into the model? Does it even consist of one or more new particles?
More unsolved problems in physics
Self-consistency of the Standard Model (currently formulated as a non-abelian gauge theory quantized through path-integrals) has not been mathematically proved. While regularized versions useful for approximate computations (for example lattice gauge theory) exist, it is not known whether they converge (in the sense of S-matrix elements) in the limit that the regulator is removed. A key question related to the consistency is the YangMills existence and mass gap problem.
Experiments indicate that neutrinos have mass, which the classic Standard Model did not allow.[61] To accommodate this finding, the classic Standard Model can be modified to include neutrino mass, although it is not obvious exactly how this should be done.
If one insists on using only Standard Model particles, this can be achieved by adding a non-renormalizable interaction of leptons with the Higgs boson.[62] On a fundamental level, such an interaction emerges in the seesaw mechanism where heavy right-handed neutrinos are added to the theory. This is natural in the left-right symmetric extension of the Standard Model[63][64] and in certain grand unified theories.[65] As long as new physics appears below or around 1014 GeV, the neutrino masses can be of the right order of magnitude.
Theoretical and experimental research has attempted to extend the Standard Model into a unified field theory or a theory of everything, a complete theory explaining all physical phenomena including constants. Inadequacies of the Standard Model that motivate such research include:
The model does not explain gravitation, although physical confirmation of a theoretical particle known as a graviton would account for it to a degree. Though it addresses strong and electroweak interactions, the Standard Model does not consistently explain the canonical theory of gravitation, general relativity, in terms of quantum field theory. The reason for this is, among other things, that quantum field theories of gravity generally break down before reaching the Planck scale. As a consequence, we have no reliable theory for the very early universe.
Some physicists consider it to be ad hoc and inelegant, requiring 19 numerical constants whose values are unrelated and arbitrary.[66] Although the Standard Model, as it now stands, can explain why neutrinos have masses, the specifics of neutrino mass are still unclear. It is believed that explaining neutrino mass will require an additional 7 or 8 constants, which are also arbitrary parameters.[67]
The Higgs mechanism gives rise to the hierarchy problem if some new physics (coupled to the Higgs) is present at high energy scales. In these cases, in order for the weak scale to be much smaller than the Planck scale, severe fine tuning of the parameters is required; there are, however, other scenarios that include quantum gravity in which such fine tuning can be avoided.[68]
The model is inconsistent with the emerging Lambda-CDM model of cosmology. Contentions include the absence of an explanation in the Standard Model of particle physics for the observed amount of cold dark matter (CDM) and its contributions to dark energy, which are many orders of magnitude too large. It is also difficult to accommodate the observed predominance of matter over antimatter (matter/antimatter asymmetry). The isotropy and homogeneity of the visible universe over large distances seems to require a mechanism like cosmic inflation, which would also constitute an extension of the Standard Model.
Currently, no proposed theory of everything has been widely accepted or verified.
See also
YangMills theory
Fundamental interaction:
Quantum electrodynamics
Strong interaction: Color charge, Quantum chromodynamics, Quark model
Weak interaction: Electroweak interaction, Fermi's interaction, Weak hypercharge, Weak isospin
Gauge theory: Introduction to gauge theory
Generation
Higgs mechanism: Higgs boson, Alternatives to the Standard Higgs Model
Lagrangian (field theory)
Open questions: CP violation, Neutrino masses, QCD matter, Quantum triviality
Quantum field theory
Standard Model: Mathematical formulation of, Physics beyond the Standard Model
Electron electric dipole moment
Notes
There are mathematical issues regarding quantum field theories still under debate (see e.g. Landau pole), but the predictions extracted from the Standard Model by current methods applicable to current experiments are all self-consistent.[2]
Although nine coloranticolor combinations mathematically exist, gluons form color octet particles. As one color-symmetric combination is linear and forms a color singlet particles, there are eight possible gluons.[42]
The W±
carries an electric charge of +1 and right-handed spin, or 1and left-handed spin. The W±
separately couple to other particles through photons and the electromagnetic interaction.
The electrically neutral Z0
boson interacts with both left-handed and right-handed regular particles and antiparticles, although with different strengths for each combination of left- and right- and of regular particles and antiparticles (see weak charge). Photons and these three gauge bosons are grouped together as a hypothetically unified, single the electroweak interaction.
A model is a representation of reality, whereas a theory is an explanation of reality; this Wikipedia article and some of the literature refers to the Standard Model as a theory.

View File

@ -0,0 +1,179 @@
// Copyright 2026 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <cstring>
#include <string>
#include <vector>
#include "evals/benchmark_helper.h"
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "io/io.h"
#include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h"
// This test can be run manually with the downloaded gemma weights.
// To run the test, pass the following flags:
// --tokenizer <tokenizer_path> --weights <weights_path>
// or just use the single-file weights file with --weights <weights_path>.
// It should pass for the following models:
// Gemma2: gemma2-2b-it
namespace gcpp {
namespace {
static const char* kQuestions =
"From the above information, please answer the following questions: "
"What did Marcia find in the sand? "
"What is Albert's preferred holiday activity? "
"How long did it take to dig out the object from the sand? "
"What is Marcia's preferred holiday activity? "
"What made the castle turrets look like daleks? "
"Which people first proposed the quark model of hadrons, and when?";
// All phrases in kAnswers must appear in the response in the order given for
// the test to pass.
static const char* kAnswers[] = {
"a ship's anchor", "a dark forest", "an hour",
"enormous sand", "castles", "limpet shells",
"Murray Gell-Mann", "George Zweig", "1964"};
std::string LoadPromptFile(const std::string& filename) {
// If the filename is empty, return an empty string.
if (filename.empty()) {
return "";
}
std::string path = testing::SrcDir() +
"evals/testdata/"
+ filename;
return ReadFileToString(Path(path));
}
std::string BuildPrompt(const std::vector<std::string>& files,
const std::string& suffix) {
std::string prompt;
for (const std::string& file : files) {
prompt += LoadPromptFile(file);
}
prompt += suffix;
return prompt;
}
class GemmaTest : public ::testing::Test {
public:
// Requires argc/argv, hence do not use `SetUpTestSuite`.
static void InitEnv(int argc, char** argv) {
HWY_ASSERT(s_env == nullptr); // Should only be called once.
ConsumedArgs consumed(argc, argv);
GemmaArgs args(argc, argv, consumed);
consumed.AbortIfUnconsumed();
s_env = new GemmaEnv(args);
const gcpp::ModelConfig& config = s_env->GetGemma()->Config();
fprintf(stderr, "Using %s\n", config.Specifier().c_str());
}
static void DeleteEnv() { delete s_env; }
protected:
std::string GemmaReply(const std::string& input,
AttentionImpl attention_mode) {
HWY_ASSERT(s_env); // must have called InitEnv()
s_env->SetMaxGeneratedTokens(256);
s_env->MutableConfig().attention_impl = attention_mode;
s_env->MutableConfig().temperature = 0.0f; // deterministic
s_env->MutableConfig().verbosity = 1;
// Always use turn structure (WrapAndTokenize).
auto response = s_env->QueryModel(input);
return response.response.substr(response.response_start_pos);
}
// Checks that the response contains the expected answer substrings int the
// expected order. Testing against a few keywords is more robust than checking
// the whole string.
void TestExpectations(const std::string& response) {
fprintf(stderr, "Response: '%s'\n", response.c_str());
size_t pos = 0;
for (const char* answer : kAnswers) {
auto found = response.find(answer, pos);
EXPECT_NE(found, std::string::npos)
<< "Response does not contain " << answer;
if (found != std::string::npos) {
pos = found + strlen(answer);
}
}
s_env->PrintProfileResults();
}
// Shared state. Requires argc/argv, so construct in main via InitEnv.
// Note that the style guide forbids non-local static variables with dtors.
static GemmaEnv* s_env;
};
GemmaEnv* GemmaTest::s_env = nullptr;
// Tests whether Gemma can find the right answer in varying levels of
// background information, ranging from the bare facts to outright distraction.
TEST_F(GemmaTest, WheatFromChaff) {
const AttentionImpl modes[] = {AttentionImpl::kOld, AttentionImpl::kFlash};
fprintf(stderr, "Warmup, mode %s\n", GetAttentionImplName(modes[0]).c_str());
auto prompt = BuildPrompt({"quark_1.txt", "holiday_story.txt"}, kQuestions);
auto response = GemmaReply(prompt, modes[0]);
TestExpectations(response);
for (const AttentionImpl mode : modes) {
const std::string mode_name = GetAttentionImplName(mode);
fprintf(stderr, "\nTesting quark_1 prompt, mode %s\n", mode_name.c_str());
prompt = BuildPrompt({"holiday_story.txt", "quark_1.txt"}, kQuestions);
response = GemmaReply(prompt, mode);
TestExpectations(response);
fprintf(stderr, "\nTesting quark_2 prompt, mode %s\n", mode_name.c_str());
prompt = BuildPrompt({"holiday_story.txt", "quark_2.txt"}, kQuestions);
response = GemmaReply(prompt, mode);
TestExpectations(response);
fprintf(stderr, "\nTesting standard_model prompt, mode %s\n",
mode_name.c_str());
prompt = BuildPrompt(
{"holiday_story.txt", "quark_2.txt", "standard_model.txt"}, kQuestions);
response = GemmaReply(prompt, mode);
TestExpectations(response);
if (s_env->MutableKVCache().SeqLen() > 38000) {
fprintf(stderr, "\nTesting special_relativity, mode %s\n",
mode_name.c_str());
prompt = BuildPrompt(
{"holiday_story.txt", "quark_2.txt", "special_relativity.txt"},
kQuestions);
} else {
fprintf(stderr, "\nSkipping special_relativity, mode %s\n",
mode_name.c_str());
prompt = BuildPrompt({"quark_1.txt", "holiday_story.txt"}, kQuestions);
}
response = GemmaReply(prompt, mode);
TestExpectations(response);
}
}
} // namespace
} // namespace gcpp
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
gcpp::GemmaTest::InitEnv(argc, argv);
int ret = RUN_ALL_TESTS();
gcpp::GemmaTest::DeleteEnv();
return ret;
}

View File

@ -24,6 +24,7 @@
#include <vector>
#include "gemma/configs.h" // ModelConfig
#include "gemma/flash_structs.h"
#include "gemma/gemma_args.h" // AttentionImpl
#include "gemma/kv_cache.h"
#include "gemma/tensor_stats.h"
@ -52,10 +53,13 @@ struct AttentionActivations {
AttentionActivations(
const ModelConfig& config, const LayerConfig& layer_config,
size_t batch_size, size_t seq_len, const RuntimeConfig& runtime_config,
const Allocator& allocator,
size_t max_workers, const Allocator& allocator,
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
: // `vocab_size == 0` means it is for Vit part, VitAttention is still
// MHA and does not use an external KV cache.
: rep_factor(max_workers *
AttentionActivations::kThreadReplicationFactor /
// `vocab_size == 0` means it is for Vit part, VitAttention
// is still MHA and does not use an external KV cache.
layer_config.heads),
q(MatFactory("q", batch_size,
config.vocab_size == 0
? layer_config.heads * 3 * layer_config.qkv_dim
@ -76,11 +80,19 @@ struct AttentionActivations {
vit_C(MatFactory("C2", batch_size, seq_len, allocator)),
pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
config.model_dim, allocator)),
att(MatFactory("att", batch_size, layer_config.heads * seq_len,
allocator)),
// att is only valid for AttentionImpl::kOld.
att(MatFactory(
"att", batch_size,
layer_config.heads *
(runtime_config.attention_impl == AttentionImpl::kOld ? seq_len
: 1),
allocator)),
att_out(MatFactory("att_out", batch_size,
layer_config.heads * layer_config.qkv_dim,
allocator)),
att_out_reps(MatFactory("att_out", batch_size * rep_factor,
layer_config.heads * layer_config.qkv_dim,
allocator)),
softmax_max(MatFactory("softmax_max", batch_size, layer_config.heads,
allocator)),
softmax_d(
@ -102,6 +114,11 @@ struct AttentionActivations {
}
return;
}
// This is a guess at the maximum number of params we might need to avoid
// reallocations. The actual number of params is determined by the number of
// query tiles, which is not known here.
flash_params.reserve(batch_size * layer_config.heads);
split_flash_params.reserve(batch_size * layer_config.heads);
// For MatMul outputs, precompute their row pointers.
// If we forget any MatMul outputs here, debug builds print a warning but
@ -125,6 +142,7 @@ struct AttentionActivations {
pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
att_out.OverrideRows(batch_size);
att_out_reps.OverrideRows(batch_size * rep_factor);
softmax_max.OverrideRows(batch_size);
softmax_d.OverrideRows(batch_size);
att_sums.OverrideRows(batch_size);
@ -132,6 +150,15 @@ struct AttentionActivations {
// `inv_timescale*` are not batched.
}
// Maximum factor by which we might scale-up work to maximize parallelism.
size_t rep_factor = 1;
// Parameters for flash attention. The size of the vector is somewhere between
// the number of query rows and 1/8th of that.
std::vector<FlashAttentionParams> flash_params;
// Parameters for flash attention, split by k-position. May be significantly
// larger than flash_params in decode mode, when the number of query rows is
// small.
std::vector<FlashAttentionParams> split_flash_params;
MatStorageT<float> q; // query
MatStorageT<BF16> q_bf;
MatStorageT<BF16> q_T; // Transposed to maximize attention speed.
@ -143,6 +170,7 @@ struct AttentionActivations {
MatStorageT<float> pre_att_rms_out;
MatStorageT<float> att; // attention vector
MatStorageT<float> att_out; // attention output
MatStorageT<float> att_out_reps; // attention output for each thread.
MatStorageT<float> softmax_max; // see OnlineSoftmaxState
MatStorageT<float> softmax_d; // see OnlineSoftmaxState
// Accumulation of attention outputs over heads
@ -151,19 +179,27 @@ struct AttentionActivations {
// Rope
MatStorageT<float> inv_timescale;
MatStorageT<float> inv_timescale_global;
// Replication factor to help evenly share work over threads.
static constexpr size_t kThreadReplicationFactor = 4;
};
// A non-owning view of AttentionActivations.
struct AttentionActivationsPtrs {
AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len)
AttentionActivationsPtrs(
const ModelConfig& config, size_t seq_len,
std::vector<FlashAttentionParams>& flash_params,
std::vector<FlashAttentionParams>& split_flash_params)
: config(config),
flash_params(flash_params),
split_flash_params(split_flash_params),
div_seq_len(static_cast<uint32_t>(seq_len)),
div_heads(static_cast<uint32_t>(config.layer_configs[0].heads)),
query_scale(ChooseQueryScale(config)) {}
AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len,
const AttentionActivations& activations)
: AttentionActivationsPtrs(config, seq_len) {
AttentionActivations& activations)
: AttentionActivationsPtrs(config, seq_len, activations.flash_params,
activations.split_flash_params) {
q = activations.q;
q_bf = activations.q_bf;
q_T = activations.q_T;
@ -173,6 +209,7 @@ struct AttentionActivationsPtrs {
pre_att_rms_out = activations.pre_att_rms_out;
att = activations.att;
att_out = activations.att_out;
att_out_reps = activations.att_out_reps;
softmax_max = activations.softmax_max;
softmax_d = activations.softmax_d;
att_sums = activations.att_sums;
@ -203,6 +240,9 @@ struct AttentionActivationsPtrs {
}
const ModelConfig& config;
// Parameters for flash attention.
std::vector<FlashAttentionParams>& flash_params;
std::vector<FlashAttentionParams>& split_flash_params;
// For the matrices below, the batch_size dimension is really qbatch.Size() *
// token_batch_size, but in all known uses, one of those is 1. Specifically,
@ -228,6 +268,7 @@ struct AttentionActivationsPtrs {
// Attention output computed from att * V, size batch_size x (q_heads *
// qkv_dim).
MatPtrT<float> att_out;
MatPtrT<float> att_out_reps;
// The maximum logit value encountered when computing att_out from att,
// size batch_size x q_heads . See OnlineSoftmaxState for details.
// WARNING: Only filled in for AttentionImpl::kOld.
@ -282,7 +323,8 @@ struct Activations {
s_w_linear_w(config.num_layers, max_workers),
attention_impl(runtime_config.attention_impl),
attention_storage(config, layer_config, batch_size, seq_len,
runtime_config, ctx.allocator, row_ptrs),
runtime_config, ctx.pools.MaxWorkers(), ctx.allocator,
row_ptrs),
attention(config, seq_len, attention_storage) {
HWY_ASSERT(batch_size != 0);

View File

@ -49,6 +49,39 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
// Returns the number of floats per vector (aka NF).
size_t FloatsPerVector() {
using DF = hn::ScalableTag<float>;
const DF df;
return hn::Lanes(df);
}
// The k-cache and v-cache are setup without knowing NF. So if it hasn't been
// done already, reshape it to take NF into account.
void MaybeReshapeCache(const MatPtrT<KV_t>& kv, MatPtrT<KV_t>& cache) {
if (kv.Cols() > cache.Cols()) {
cache.ReshapePackedRowsToCols(2 * FloatsPerVector());
}
}
// Transposes a single row of the kv cache into the k-cache and v-cache.
void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k,
KV_t* HWY_RESTRICT v, size_t qkv_dim) {
// This is inefficient, as the writes are scattered over cache lines, but it
// is a tiny fraction of the overall computation, and it is linear in the
// token length.
const size_t kFloatsPerTile = 2 * FloatsPerVector();
for (size_t i = 0; i < qkv_dim; i += 2) {
k[i * kFloatsPerTile] = kv[i];
k[i * kFloatsPerTile + 1] = kv[i + 1];
}
for (size_t i = 0; i < qkv_dim; i += kFloatsPerTile) {
for (size_t j = 0; j < kFloatsPerTile; j++) {
v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
}
}
}
// Computes Q.K scores, which are "logits" (or scores) stored to att.
// `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
@ -280,6 +313,11 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
/*add=*/nullptr, env, kv_rows);
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).k_cache);
MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).v_cache);
}
const size_t kFloatsPerVector = FloatsPerVector();
// Apply positional encodings for K.
// Note that 2D parallelism is not worth the fork/join overhead because the
@ -299,6 +337,26 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
layer_idx * cache_layer_size +
head * qkv_dim * 2;
// Note that k_cache and v_cache are different shapes.
// The innermost dimension of k is 2 values from qkv_dim because they
// are going to be used in a BF16 dot product involving pairs of
// values over NF k positions.
// The innermost dimension of v is 2NF values from qkv_dim because they
// will be loaded into a BF16 vector to be scaled and added to the
// cached attention output in 2 NF-sized registers.
// TODO(rays): factor out these calculations into functions.
auto& k_cache = qbatch.KV(qi).k_cache;
KV_t* HWY_RESTRICT k =
k_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
(layer_idx * cache_layer_size + head * qkv_dim * 2) *
kFloatsPerVector +
(cache_pos % (2 * kFloatsPerVector)) * 2;
auto& v_cache = qbatch.KV(qi).v_cache;
KV_t* HWY_RESTRICT v =
v_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
(layer_idx * cache_layer_size + head * qkv_dim * 2) *
kFloatsPerVector +
(cache_pos % (2 * kFloatsPerVector)) * 2 * kFloatsPerVector;
HWY_ALIGN float kv_f32[2 * kMaxQKVDim];
const hn::ScalableTag<float> df;
@ -319,13 +377,17 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
/*mul=*/1.0f);
CompressPerThread tls;
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
// This is inefficient, as multiple threads are writing the same K
// cache line, but the input is generated by a matmul, so it is
// difficult to change, and it probably isn't significant.
TransposeKVCacheRow(kv, k, v, qkv_dim);
});
}
void GemmaAttention(size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer,
AttentionActivationsPtrs& activations, QBatch& qbatch,
MatMulEnv& env, int flags) {
MatMulEnv& env, AttentionImpl attention_impl, int flags) {
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttention);
const LayerConfig& layer_config = layer.layer_config;
@ -335,15 +397,16 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
(void)layer_config; // only used in HWY_DASSERT
ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
if (flags & kAttentionUseOld) {
if (attention_impl == AttentionImpl::kOld) {
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer.query_norm_scale,
activations, qbatch, env.ctx);
} else {
// * 2 does not help on Turin.
FlashAttention(num_tokens,
/*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1,
/*target_parallelism=*/env.ctx.pools.MaxWorkers() *
AttentionActivations::kThreadReplicationFactor,
layer_idx, layer.query_norm_scale, activations, qbatch,
env.ctx);
env.ctx, attention_impl);
}
SumHeads(layer, activations, env);
}

View File

@ -31,6 +31,13 @@ namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
size_t FloatsPerVector(); \
\
void MaybeReshapeCache(const MatPtrT<KV_t>& kv, MatPtrT<KV_t>& cache); \
\
void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k, \
KV_t* HWY_RESTRICT v, size_t qkv_dim); \
\
void PositionalEncodingQK(float* qk, size_t layer_idx, \
const AttentionActivationsPtrs& activations, \
ThreadingContext& ctx, size_t worker, size_t pos, \
@ -53,7 +60,8 @@ namespace gcpp {
void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
const LayerWeightsPtrs& layer, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \
MatMulEnv& env, int flags); \
MatMulEnv& env, AttentionImpl attention_impl, \
int flags); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE

View File

@ -1,8 +1,10 @@
#include <cstddef>
#include <cstdlib>
#include <cstring> // strcmp
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <vector>
#include "gtest/gtest.h"
@ -105,7 +107,8 @@ struct TestAttentionState {
tokens(num_tokens),
attention_storage_(model_state.config, model_state.layer_config,
batch_size, num_tokens, runtime_config,
state.ctx.allocator, row_ptrs_),
state.ctx.pools.MaxWorkers(), state.ctx.allocator,
row_ptrs_),
attention(model_state.config, num_tokens, attention_storage_) {
for (size_t i = 0; i < qbatch_size; ++i) {
kv_caches.emplace_back(model_state.config, inference_args,
@ -143,6 +146,7 @@ struct TestAttentionState {
};
double GetTolerance() {
if (IsBF16<KV_t>()) return 1e-2;
const char* target_name = hwy::TargetName(HWY_TARGET);
if (strncmp(target_name, "AVX2", 4) == 0) {
return 2e-2;
@ -155,6 +159,57 @@ double GetTolerance() {
}
}
// Comparison function for computations that used BF16, whether the result is
// stored in BF16 or F32.
// Compare with absolute tolerance for values with small magnitudes.
// Compare with relative tolerance for values with larger magnitudes.
template <typename T>
bool CompareArraySimilarBF16(const T* expected, const T* actual, size_t count,
const char* target_name, const char* filename,
int line) {
constexpr double kTolerance = 3e-2;
for (size_t i = 0; i < count; ++i) {
const double exp = hwy::ConvertScalarTo<double>(expected[i]);
const double act = hwy::ConvertScalarTo<double>(actual[i]);
const double l1 = std::abs(act - exp);
// Cannot divide, so check absolute error.
if (std::abs(exp) <= 1.0) {
if (l1 > kTolerance) {
std::string array_values = hwy::detail::FormatMismatchedArrays(
expected, actual, count, kTolerance);
HWY_WARN("%s %s:%d %s mismatch %zu of %zu: %E %E l1 %E tol %E%s\n",
target_name, filename, line, "BF16", i, count, exp, act, l1,
kTolerance, array_values.c_str());
return false;
}
} else { // relative
const double rel = l1 / exp;
if (rel > kTolerance) {
std::string array_values = hwy::detail::FormatMismatchedArrays(
expected, actual, count, kTolerance);
HWY_WARN("%s %s:%d %s mismatch %zu of %zu: %E %E rel %E tol %E%s\n",
target_name, filename, line, "BF16", i, count, exp, act, rel,
kTolerance, array_values.c_str());
return false;
}
}
}
return true;
}
template <typename T>
bool CompareArraySimilar(const T* expected, const T* actual, size_t count,
const char* target_name, const char* filename,
int line) {
if constexpr (IsBF16<KV_t>()) {
return CompareArraySimilarBF16(expected, actual, count, target_name,
filename, line);
} else {
return hwy::CompareArraySimilar(expected, actual, count, GetTolerance(),
target_name, filename, line);
}
}
template <size_t kNumTokens, size_t kQBatchSize, size_t kDims>
void CompareAttSumsWithGolden(
const AttentionActivationsPtrs& attention,
@ -170,9 +225,9 @@ void CompareAttSumsWithGolden(
for (size_t j = 0; j < kDims; ++j) {
actual_row[j] = hwy::F32FromBF16(attention.att_sums.Row(i)[j]);
}
EXPECT_TRUE(hwy::CompareArraySimilar(
golden[token_idx][qi], actual_row.get(), kDims, GetTolerance(),
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
EXPECT_TRUE(CompareArraySimilar(golden[token_idx][qi], actual_row.get(),
kDims, hwy::TargetName(HWY_TARGET),
__FILE__, __LINE__))
<< "att_sums mismatch for token_idx=" << token_idx << " qi=" << qi;
}
}
@ -200,19 +255,20 @@ void CompareKVCacheWithGolden(
for (size_t token_idx = 0; token_idx < kNumTokens; ++token_idx) {
for (size_t qi = 0; qi < kQBatchSize; ++qi) {
const float* cache_row =
const BF16* cache_row =
kv_caches[qi].kv_cache.Row(start_offset + token_idx);
for (size_t j = 0; j < kDims; ++j) {
actual_k_row[j] = cache_row[kv_offset + j];
actual_v_row[j] = cache_row[kv_offset + qkv_dim + j];
actual_k_row[j] = hwy::ConvertScalarTo<float>(cache_row[kv_offset + j]);
actual_v_row[j] =
hwy::ConvertScalarTo<float>(cache_row[kv_offset + qkv_dim + j]);
}
EXPECT_TRUE(hwy::CompareArraySimilar(
k_golden[token_idx][qi], actual_k_row.get(), kDims, GetTolerance(),
EXPECT_TRUE(CompareArraySimilar(
k_golden[token_idx][qi], actual_k_row.get(), kDims,
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
<< "K cache mismatch for token_idx=" << token_idx << " qi=" << qi
<< " kv_head=" << kv_head;
EXPECT_TRUE(hwy::CompareArraySimilar(
v_golden[token_idx][qi], actual_v_row.get(), kDims, GetTolerance(),
EXPECT_TRUE(CompareArraySimilar(
v_golden[token_idx][qi], actual_v_row.get(), kDims,
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
<< "V cache mismatch for token_idx=" << token_idx << " qi=" << qi
<< " kv_head=" << kv_head;
@ -238,8 +294,8 @@ void CompareQVecsWithGolden(
for (size_t j = 0; j < kDims; ++j) {
actual_q_row[j] = q_row[head_offset + j];
}
EXPECT_TRUE(hwy::CompareArraySimilar(
q_golden[token_idx][qi], actual_q_row.get(), kDims, GetTolerance(),
EXPECT_TRUE(CompareArraySimilar(
q_golden[token_idx][qi], actual_q_row.get(), kDims,
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
<< "Q vec mismatch for token_idx=" << token_idx << " qi=" << qi
<< " q_head=" << q_head;
@ -267,42 +323,42 @@ const float kGoldenAttSums[kNumTokens][kQBatchSize][kDimsToCompare] = {
26.875, 63, 3.34375, -67.5, 31.125, -190, 125},
{-30.375, -17.875, 51.75, -78, -84, 6.40625, 15.375, 70, -22.875, 20.125,
-14.9375, -109.5, 76, 9.25, -142, 29.5, -105}},
{{-32.75, 38.25, 78.5, 107.5, 20.25, 197, -136, 42.5, -84, 25.625, 4.96875,
{{-32.75, 38.25, 78.5, 107.5, 20.25, 197, -136, 42.5, -84, 25.625, 5.35875,
128, 27.25, -161, 19.125, -58, 97.5},
{-18.5, -18, 135, -13.4375, -6.625, -45.75, 29.625, 93, 18.625, 75.5,
102.5, -184, 52.75, 83.5, -71, 46.5, -52}},
{{-16.375, -61.5, -58.25, -27.375, -28, 71, -109.5, 60.25, 3.125, -29.125,
6.90625, 150, 144, -155, -47.25, -98.5, 3.5625},
{-19, -16.75, 129, 0.59765625, -82, 123.5, 60.75, -36.75, -77, 26.625, 51,
-66.5, -0.84765625, -46.5, -152, -2.9375, -81}},
{{3.984375, 83, -41.75, 39.5, -203, 110, -76, 131, 0.4609375, -44.5, -63.75,
{{-16.375, -61.5, -58.25, -27.375, -28, 71, -109.5, 60.25, 3.625, -29.125,
6.4625, 150, 144, -155, -47.25, -98.5, 3.5625},
{-19, -16.75, 129, 0.628925, -82, 123.5, 60.75, -36.75, -77, 26.625, 51,
-66.5, -0.62165625, -46.5, -152, -2.9375, -81}},
{{3.684375, 83, -41.75, 39.5, -203, 110, -76, 131, 1.0069375, -44.5, -63.75,
-46, -22, -19.375, -16.125, -148, 20.875},
{-47, -19.5, 58, 81.5, 21.75, -30, -118, 44.25, -149, 22.5, 188, -66.5, 33,
{-47, -19.5, 58, 81.5, 23.35, -30, -118, 44.25, -149, 22.5, 188, -66.5, 33,
10.9375, -52.5, 23.25, 75}},
{{64, -31, -89, -92.5, -11.1875, -54.75, -302, 3.453125, -108, 39.25,
{{64, -31, -89, -92.5, -11.1875, -54.75, -302, 4.213125, -108, 39.25,
-34.75, 18, -52, 100, -186, -75.5, 50.75},
{7.6875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.4375, 82.5,
{7.1875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.4375, 82.5,
39.25, 65, 47.25, -89.5, -34.25, 137}},
{{39.75, 17.875, 115, 38.75, -44, 139, -53.25, -23.875, -13.0625, 38.5,
32.5, 53.75, 109, 4.09375, 57.5, -20.5, 132},
{143, 249, 5.09375, 0.83984375, 27.875, -5.84375, 30.25, -101.5, 65.5,
13.5, 195, -10.0625, 97.5, 2.203125, -97.5, -100, -19.25}},
32.5, 53.75, 109, 4.62375, 57.5, -20.5, 132},
{143, 249, 4.9375, 1.33984375, 27.875, -5.84375, 30.25, -101.5, 65.5, 13.5,
195, -10.0625, 97.5, 1.903125, -97.5, -100, -19.25}},
{{-30.125, -169, -150, 58, -35.75, 22.75, 36.5, -32.25, -8.9375, 55.25,
-117, 26.375, 39.5, 125, 66, 48.75, 20.75},
{137, 5.25, 61.25, 37, -42.75, 240, 62, -164, 11.3125, 173, 174, 23.5,
{137, 3.85, 61.25, 37, -42.75, 240, 62, -164, 10.3125, 173, 174, 23.5,
88.5, 48.5, -46.25, -36.75, 101.5}},
{{-103, -47.5, 39, -48, -67.5, 121, -136, 99, 80, -47.5, 107.5, 48.75, 97.5,
{{-103, -47.5, 39, -48, -67.5, 121, -136, 99, 80, -47.5, 107.5, 43.75, 97.5,
125, -53.5, -14.625, 262},
{29.875, 7.34375, -36.75, -14.5, -27.5, 44.75, -67.5, -40.75, 71.5, 172,
{28.075, 6.64375, -36.75, -14.5, -27.5, 44.75, -67.5, -40.75, 71.5, 172,
81, -27.25, -3.03125, 111, -167, 59, 176}},
{{-37.25, 109.5, -26.125, -115.5, 108, 57.25, 1.3671875, 72, -122.5, 59.25,
-52, -12.625, 43.25, 16.25, -41.75, 26.5, 70.5},
{40.25, 53.25, -142, 78.5, 38, 4.3125, -27.75, -134, -85, 107.5, 2.5, 93.5,
{40.25, 53.25, -142, 78.5, 38, 4.625, -27.75, -134, -85, 107.5, 2.5, 93.5,
58.25, 173, -53.5, 25.125, 4.8125}},
{{-8.4375, -35, -35.5, 131, -33.25, 106, 109.5, -92, -135, 80, 21.5,
-17.125, 15.25, 143, -27, 103, 101},
{-77, 40.75, -10.125, 33.25, -33, 104, -7.6875, 85.5, -40, 93, 61, 14.5625,
8.125, -99.5, 13.6875, -11.6875, 33}},
8.55, -99.5, 14.6875, -11.6875, 33}},
};
// Layer 0, *K*V Head 0
@ -538,7 +594,7 @@ void RunAttentionTest(AttentionImpl attention_impl) {
GemmaAttention(attention_state.tokens.size(), 0, model_state.layer,
attention_state.attention, *attention_state.qbatch, state.env,
AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16));
attention_impl, /*flags=*/0);
CompareAttSumsWithGolden(attention_state.attention, kGoldenAttSums);
CompareKVCacheWithGolden(model_state.config,

View File

@ -712,9 +712,21 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) {
}
}
// Keep in sync with enum class AttentionImpl.
const char* kAttentionImplNames[] = {
"old", "flash",
"unknown" // keep last
};
std::string GetAttentionImplName(AttentionImpl impl) {
return kAttentionImplNames[static_cast<size_t>(impl)];
}
AttentionImpl GetAttentionImpl(const std::string& impl) {
if (impl == "old") return AttentionImpl::kOld;
if (impl == "flash") return AttentionImpl::kFlash;
if (impl == GetAttentionImplName(AttentionImpl::kOld))
return AttentionImpl::kOld;
if (impl == GetAttentionImplName(AttentionImpl::kFlash))
return AttentionImpl::kFlash;
HWY_WARN("Unknown attention implementation: %s. Using kOld.\n", impl.c_str());
return AttentionImpl::kOld;
}

View File

@ -81,11 +81,12 @@ static inline bool EnumValid(LayerAttentionType type) {
}
enum class AttentionImpl {
kOld,
kFlash,
kOld, // Previous Attention implementation
kFlash, // Flash Attention (default)
kSentinel,
};
std::string GetAttentionImplName(AttentionImpl impl);
AttentionImpl GetAttentionImpl(const std::string& impl);
/*

File diff suppressed because it is too large Load Diff

View File

@ -44,22 +44,13 @@ namespace gcpp {
float* HWY_RESTRICT att_out, \
ThreadingContext& ctx, size_t worker); \
\
Tile4FlashState TileFlashAttention4( \
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \
const MatPtrT<KV_t>& k, size_t start_pos, \
const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \
size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx, \
const LayerWeightsPtrs& layer, const AttentionActivations& activations, \
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets, \
ThreadingContext& ctx, const size_t worker); \
\
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
size_t total_tasks, size_t target_parallelism); \
\
void FlashAttention(size_t num_tokens, size_t target_parallelism, \
size_t layer_idx, const MatPtr& query_norm_scale, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \
ThreadingContext& ctx); \
ThreadingContext& ctx, AttentionImpl attention_impl); \
\
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE

View File

@ -62,16 +62,17 @@ namespace HWY_NAMESPACE {
using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
void SetMat(const size_t offset, MatPtrT<float>& mat) {
template <typename T>
void SetMat(const size_t offset, MatPtrT<T>& mat) {
const size_t kOuter = mat.Extents().rows;
const size_t kInner = mat.Extents().cols;
const float i_scale = 1.0f / kInner;
const float j_scale = 1.0f / kOuter;
for (size_t i = 0; i < kOuter; ++i) {
float* row = mat.Row(i);
T* row = mat.Row(i);
for (size_t j = 0; j < kInner; ++j) {
row[j] =
static_cast<float>((i * kInner * i_scale + (j + offset) * j_scale));
row[j] = hwy::ConvertScalarTo<T>(
static_cast<float>((i * kInner * i_scale + (j + offset) * j_scale)));
}
}
}
@ -94,14 +95,15 @@ void AssertClose(const MatPtrT<float>& a, const MatPtrT<float>& b) {
if (rel_abs_delta > 0.0f) {
rel_abs_delta /= std::max(std::abs(a_row[c]), std::abs(b_row[c]));
}
EXPECT_LT(rel_abs_delta, 1e-5)
EXPECT_LT(rel_abs_delta, 1e-3)
<< "a[" << r << "," << c << "]=" << a_row[c] << ", b[" << r << ","
<< c << "]=" << b_row[c];
}
}
}
void TestFlashAttention(size_t target_parallelism) {
void TestFlashAttention(size_t target_parallelism,
AttentionImpl attention_impl) {
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
constexpr size_t kOuter = 1024;
@ -112,7 +114,9 @@ void TestFlashAttention(size_t target_parallelism) {
const LayerConfig& layer_config = config.layer_configs[0];
const LayerWeightsPtrs layers(0, layer_config, tensor_info_registry);
InferenceArgs inference_args;
inference_args.attention_impl = "flash";
// attention_impl must be old in order for the att intermediate to be
// allocated for the old attention.
inference_args.attention_impl = "old";
RuntimeConfig runtime_config;
inference_args.CopyTo(runtime_config);
KVCache kv_cache(config, inference_args, ctx.allocator);
@ -129,7 +133,8 @@ void TestFlashAttention(size_t target_parallelism) {
const size_t batch_size = kOuter;
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
AttentionActivations attention_storage(config, layer_config, batch_size,
kOuter, runtime_config, ctx.allocator,
kOuter, runtime_config,
ctx.pools.MaxWorkers(), ctx.allocator,
row_ptrs);
AttentionActivationsPtrs attention(config, kOuter, attention_storage);
const size_t qkv_dim = layer_config.qkv_dim;
@ -140,7 +145,10 @@ void TestFlashAttention(size_t target_parallelism) {
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
const size_t seq_len =
static_cast<size_t>(attention.div_seq_len.GetDivisor());
MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).k_cache);
MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).v_cache);
auto& kvc = qbatch.KV(0).kv_cache;
const size_t kFloatsPerTile = 2 * FloatsPerVector();
for (size_t h = 0; h < layer_config.heads; ++h) {
// Make strided views into the kv cache for
// this query and head.
@ -151,6 +159,17 @@ void TestFlashAttention(size_t target_parallelism) {
v.SetPtr(kvc.Row(0) + head_offset + qkv_dim, kvc.Stride());
SetMat(h + layer_config.heads, k);
SetMat(h + layer_config.heads * 2, v);
for (size_t p = 0; p < tokens.size(); ++p) {
KV_t* HWY_RESTRICT k_src = k.Row(p);
KV_t* HWY_RESTRICT k_dest = qbatch.KV(0).k_cache.Row(p / kFloatsPerTile) +
head_offset * kFloatsPerTile / 2 +
p % kFloatsPerTile * 2;
KV_t* HWY_RESTRICT v_dest = qbatch.KV(0).v_cache.Row(p / kFloatsPerTile) +
head_offset * kFloatsPerTile / 2 +
p % kFloatsPerTile * kFloatsPerTile;
TransposeKVCacheRow(k_src, k_dest, v_dest, qkv_dim);
}
}
SetMat(1, attention.q);
DotSoftmaxWeightedSum(tokens.size(), 0, layers.query_norm_scale, attention,
@ -165,18 +184,19 @@ void TestFlashAttention(size_t target_parallelism) {
tokens.size() * div_qbatch.GetDivisor() * layer_config.heads;
const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, tokens.size(),
total_tasks, target_parallelism);
printf("FlashAttention: target_parallelism=%zu, kNF=%zu, kVTileSize=%zu\n",
target_parallelism, kNF, kVTileSize);
printf("FlashAttention: parallelism=%zu, kNF=%zu, kVTileSize=%zu, mode %s\n",
target_parallelism, kNF, kVTileSize,
GetAttentionImplName(attention_impl).c_str());
FlashAttention(tokens.size(), target_parallelism, 0, layers.query_norm_scale,
attention, qbatch, ctx);
attention, qbatch, ctx, attention_impl);
AssertClose(attention.att_out, *saved_att);
ctx.profiler.PrintResults();
}
void TestAttention() {
TestFlashAttention(8192);
TestFlashAttention(2048);
TestFlashAttention(256);
TestFlashAttention(8192, AttentionImpl::kFlash);
TestFlashAttention(2048, AttentionImpl::kFlash);
TestFlashAttention(256, AttentionImpl::kFlash);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)

View File

@ -2,11 +2,19 @@
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_
#include <stddef.h>
#include <stdint.h>
#include <limits>
namespace gcpp {
// The vertical tile size in flash attention when register lanes correspond to
// K-timesteps, and the number of registers is 4 for 4 Q-rows.
static constexpr size_t k4xNFVTileSize = 4;
// The vertical tile size in flash attention when register lanes correspond to
// K-timesteps, and the number of registers is 8 for 8 Q-rows.
static constexpr size_t k8xNFVTileSize = 8;
// State for computing softmax in a streaming ("online") manner,
// avoiding large intermediate values by subtracting the running maximum.
// For a sequence x_1, ..., x_n:
@ -20,10 +28,44 @@ struct OnlineSoftmaxState {
float d = 0.0f;
};
static constexpr size_t kVTileSize4 = 4;
struct Tile4FlashState {
OnlineSoftmaxState row_states[kVTileSize4];
OnlineSoftmaxState row_states[k8xNFVTileSize];
};
// Parameters for a strip of tiles of flash attention. For processing a strip
// of tiles, each of 1, k4xNFVTileSize, or k8xNFVTileSize Q-rows, by NF
// k-positions. The total width of the strip might cover the entire sequence,
// or a part of it, depending on whether the strip has been split.
struct FlashAttentionParams {
// Vertical tile size gives the number used in the k8xNFVTileSize arrays.
// It is the number of Q rows in the tile.
uint32_t v_tile_size = 0;
// min start position across all rows in the tile determines the
// mask used for the tile.
uint32_t min_start_pos = std::numeric_limits<uint32_t>::max();
// max last position across all rows in the tile determines the mask
// used for the tile.
uint32_t max_last_pos = 0;
// Index into the qbatch.KV is the same for each row in the tile.
uint32_t qi_index;
// Index into the kv_cache is the same for each row in the tile.
uint32_t kv_offset;
// In the original task, the index to the split tasks of the first split task.
uint32_t split_index = 0;
// The index of the split for running split attention.
uint32_t i_of_n = 0;
// Offsets into original Q for each row in the tile.
uint32_t q_offsets[k8xNFVTileSize];
// Offsets into att_out for each row in the tile.
uint32_t out_offsets[k8xNFVTileSize];
// Start k-positions for each row in the tile.
uint32_t start_pos[k8xNFVTileSize];
// Last k-positions for each row in the tile. Inclusive.
uint32_t last_pos[k8xNFVTileSize];
// Row index to att_out.
uint32_t tq_idx[k8xNFVTileSize];
// Flash attention state for the tile.
Tile4FlashState end_state;
};
} // namespace gcpp

View File

@ -83,9 +83,8 @@ void Attention(LayerAttentionType type, const size_t num_tokens,
if (type == LayerAttentionType::kGemma) {
// TODO: remove flag to enable FlashAttention.
GemmaAttention(
num_tokens, layer_idx, layer, activations.attention, qbatch, env,
AttentionImplToFlags(activations.attention_impl, HWY_NATIVE_DOT_BF16));
GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
env, activations.attention_impl, /*flags=*/0);
}
}
@ -595,6 +594,9 @@ static void GenerateT(const ModelConfig& config,
const size_t max_gen_steps = PrefillTBatchOrQBatch(
config, runtime_config, weights, activations, qbatch, env, timing_info);
// No-op if the profiler is disabled, but useful to separate prefill and
// generate phases for profiling.
env.ctx.profiler.PrintResults();
hwy::BitSet4096<> non_eos; // indexed by qi

View File

@ -43,6 +43,17 @@ static size_t CappedSeqLen(const ModelConfig& config,
KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator)
: kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
// WARNING: the rows and cols of k_cache and v_cache will be modified
// before use!
// The rows will be reduced by a factor of 2xkFloatsPerVector, and the
// cols will be increased by 2xkFloatsPerVector on first use. This is to
// avoid making KVCache another class that has to be duplicated for each
// machine architecture, since kFloatsPerVector is architecture dependent.
// The change is shape is safe only if the padding is kPacked.
k_cache("k", Extents2D(kv_extents.rows, kv_extents.cols / 2), allocator,
MatPadding::kPacked),
v_cache("v", Extents2D(kv_extents.rows, kv_extents.cols / 2), allocator,
MatPadding::kPacked),
allocator_(allocator) {}
KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
@ -55,6 +66,8 @@ KVCache KVCache::Copy() {
KVCache copy(kv_cache.Extents(), allocator_);
CopyMat(kv_cache, copy.kv_cache);
CopyMat(k_cache, copy.k_cache);
CopyMat(v_cache, copy.v_cache);
return copy;
}

View File

@ -30,7 +30,7 @@
namespace gcpp {
using KV_t = float;
using KV_t = BF16;
// A non-owning view of a KVCache.
struct KVCachePtr {
@ -38,6 +38,8 @@ struct KVCachePtr {
size_t SeqLen() const;
MatPtrT<KV_t> kv_cache;
MatPtrT<KV_t> k_cache;
MatPtrT<KV_t> v_cache;
};
struct KVCache {
@ -52,10 +54,33 @@ struct KVCache {
}
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
// The format of k_cache indicates that there are pairs of values from
// qkv_dim in groups of 2x kFloatsPerVector(=NF) elements from the sequence,
// in groups of qkv_dim/2 elements in groups of kv_heads elements.
// This enables sequential loading of the data when filling 2 vectors with
// NF sequence elements of pairs of BF16 qkv values. The next vector then
// continues reading the rest of qkv.
// [seq_len / 2NF, layers * kv_heads * qkv_dim/2 * 2NF * 2]
MatStorageT<KV_t> k_cache;
// v_cache is formatted to allow sequential access to V during scaling and
// update of att_out.
// Originally [seq_len, layers * kv_heads * qkv_dim]
// v_cache is transposed to:
// [layers, kv_heads, seq_len, qkv_dim], reshaped to:
// [layers, kv_heads, seq_len/(2NF), 2NF, qkv_dim/(2NF), 2NF]
// then transposed to:
// [seq_len/(2NF), layers, kv_heads, qkv_dim/(2NF), 2NF, 2NF]
// and finally packed in a 2D MatStorageT as:
// [seq_len/(2NF), layers * kv_heads * qkv_dim/(2NF) * 2NF * 2NF]
// This allows sequential reads of 2NF registers each of 2NF BF16 values,
// repeatedly until all of qkv_dim is read.
MatStorageT<KV_t> v_cache;
KVCachePtr ToPtr() {
return KVCachePtr{
.kv_cache = kv_cache,
.k_cache = k_cache,
.v_cache = v_cache,
};
}

View File

@ -614,267 +614,6 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(const float c,
});
}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void Mul16(DF df, const VF scale, VF& sum0,
VF& sum1, VF& sum2, VF& sum3, VF& sum4,
VF& sum5, VF& sum6, VF& sum7, VF& sum8,
VF& sum9, VF& sum10, VF& sum11,
VF& sum12, VF& sum13, VF& sum14,
VF& sum15) {
sum0 = hn::Mul(sum0, hn::BroadcastLane<0>(scale));
sum1 = hn::Mul(sum1, hn::BroadcastLane<1>(scale));
sum2 = hn::Mul(sum2, hn::BroadcastLane<2>(scale));
sum3 = hn::Mul(sum3, hn::BroadcastLane<3>(scale));
sum4 = hn::Mul(sum4, hn::BroadcastLane<4>(scale));
sum5 = hn::Mul(sum5, hn::BroadcastLane<5>(scale));
sum6 = hn::Mul(sum6, hn::BroadcastLane<6>(scale));
sum7 = hn::Mul(sum7, hn::BroadcastLane<7>(scale));
sum8 = hn::Mul(sum8, hn::BroadcastLane<8>(scale));
sum9 = hn::Mul(sum9, hn::BroadcastLane<9>(scale));
sum10 = hn::Mul(sum10, hn::BroadcastLane<10>(scale));
sum11 = hn::Mul(sum11, hn::BroadcastLane<11>(scale));
sum12 = hn::Mul(sum12, hn::BroadcastLane<12>(scale));
sum13 = hn::Mul(sum13, hn::BroadcastLane<13>(scale));
sum14 = hn::Mul(sum14, hn::BroadcastLane<14>(scale));
sum15 = hn::Mul(sum15, hn::BroadcastLane<15>(scale));
}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void Mul16(DF df, const VF scale, VF& sum0,
VF& sum1, VF& sum2, VF& sum3, VF& sum4,
VF& sum5, VF& sum6, VF& sum7, VF& sum8,
VF& sum9, VF& sum10, VF& sum11,
VF& sum12, VF& sum13, VF& sum14,
VF& sum15) {}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void Mul8(DF df, const VF scale, VF& sum0, VF& sum1,
VF& sum2, VF& sum3, VF& sum4, VF& sum5,
VF& sum6, VF& sum7) {
sum0 = hn::Mul(sum0, hn::BroadcastLane<0>(scale));
sum1 = hn::Mul(sum1, hn::BroadcastLane<1>(scale));
sum2 = hn::Mul(sum2, hn::BroadcastLane<2>(scale));
sum3 = hn::Mul(sum3, hn::BroadcastLane<3>(scale));
sum4 = hn::Mul(sum4, hn::BroadcastLane<4>(scale));
sum5 = hn::Mul(sum5, hn::BroadcastLane<5>(scale));
sum6 = hn::Mul(sum6, hn::BroadcastLane<6>(scale));
sum7 = hn::Mul(sum7, hn::BroadcastLane<7>(scale));
}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void Mul8(DF df, const VF scale, VF& sum0, VF& sum1,
VF& sum2, VF& sum3, VF& sum4, VF& sum5,
VF& sum6, VF& sum7) {}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16(
DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2,
VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9,
VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) {
sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0);
sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1);
sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2);
sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3);
sum4 = hn::MulAdd(common, hn::BroadcastLane<4>(split), sum4);
sum5 = hn::MulAdd(common, hn::BroadcastLane<5>(split), sum5);
sum6 = hn::MulAdd(common, hn::BroadcastLane<6>(split), sum6);
sum7 = hn::MulAdd(common, hn::BroadcastLane<7>(split), sum7);
sum8 = hn::MulAdd(common, hn::BroadcastLane<8>(split), sum8);
sum9 = hn::MulAdd(common, hn::BroadcastLane<9>(split), sum9);
sum10 = hn::MulAdd(common, hn::BroadcastLane<10>(split), sum10);
sum11 = hn::MulAdd(common, hn::BroadcastLane<11>(split), sum11);
sum12 = hn::MulAdd(common, hn::BroadcastLane<12>(split), sum12);
sum13 = hn::MulAdd(common, hn::BroadcastLane<13>(split), sum13);
sum14 = hn::MulAdd(common, hn::BroadcastLane<14>(split), sum14);
sum15 = hn::MulAdd(common, hn::BroadcastLane<15>(split), sum15);
}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16(
DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2,
VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9,
VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) {}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const VF common, const VF split,
VF& sum0, VF& sum1, VF& sum2, VF& sum3,
VF& sum4, VF& sum5, VF& sum6,
VF& sum7) {
sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0);
sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1);
sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2);
sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3);
sum4 = hn::MulAdd(common, hn::BroadcastLane<4>(split), sum4);
sum5 = hn::MulAdd(common, hn::BroadcastLane<5>(split), sum5);
sum6 = hn::MulAdd(common, hn::BroadcastLane<6>(split), sum6);
sum7 = hn::MulAdd(common, hn::BroadcastLane<7>(split), sum7);
}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const VF common, const VF split,
VF& sum0, VF& sum1, VF& sum2, VF& sum3,
VF& sum4, VF& sum5, VF& sum6,
VF& sum7) {}
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF split,
VF& sum0, VF& sum1, VF& sum2,
VF& sum3) {
sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0);
sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1);
sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2);
sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3);
}
// For an 8xNF tile of float values in 8xNF-lane registers, multiplies 8 rows
// of V by the corresponding values in c0-c7 and adds them to NF rows of out,
// after first prescaling out by scale.
// The depth (size) must be a multiple of NF.
template <class DF, class VF = hn::Vec<DF>>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile(
DF df, const VF scale, const VF c0, const VF c1, const VF c2, const VF c3,
const VF c4, const VF c5, const VF c6, const VF c7, const MatPtrT<float>& v,
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
size_t i = 0;
while (i + NF <= size) {
if HWY_LANES_CONSTEXPR (NF == 16) {
VF out0, out1, out2, out3, out4, out5, out6, out7;
VF out8, out9, out10, out11, out12, out13, out14, out15;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out4 = hn::Load(df, out + i + out_offsets[4]);
out5 = hn::Load(df, out + i + out_offsets[5]);
out6 = hn::Load(df, out + i + out_offsets[6]);
out7 = hn::Load(df, out + i + out_offsets[7]);
out8 = hn::Load(df, out + i + out_offsets[8]);
out9 = hn::Load(df, out + i + out_offsets[9]);
out10 = hn::Load(df, out + i + out_offsets[10]);
out11 = hn::Load(df, out + i + out_offsets[11]);
out12 = hn::Load(df, out + i + out_offsets[12]);
out13 = hn::Load(df, out + i + out_offsets[13]);
out14 = hn::Load(df, out + i + out_offsets[14]);
out15 = hn::Load(df, out + i + out_offsets[15]);
Mul16(df, scale, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x0 = hn::Load(df, v.Row(pos[0]) + i);
MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x1 = hn::Load(df, v.Row(pos[1]) + i);
MulAdd16(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x2 = hn::Load(df, v.Row(pos[2]) + i);
MulAdd16(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x3 = hn::Load(df, v.Row(pos[3]) + i);
MulAdd16(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x4 = hn::Load(df, v.Row(pos[4]) + i);
MulAdd16(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x5 = hn::Load(df, v.Row(pos[5]) + i);
MulAdd16(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x6 = hn::Load(df, v.Row(pos[6]) + i);
MulAdd16(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x7 = hn::Load(df, v.Row(pos[7]) + i);
MulAdd16(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
hn::Store(out4, df, out + i + out_offsets[4]);
hn::Store(out5, df, out + i + out_offsets[5]);
hn::Store(out6, df, out + i + out_offsets[6]);
hn::Store(out7, df, out + i + out_offsets[7]);
hn::Store(out8, df, out + i + out_offsets[8]);
hn::Store(out9, df, out + i + out_offsets[9]);
hn::Store(out10, df, out + i + out_offsets[10]);
hn::Store(out11, df, out + i + out_offsets[11]);
hn::Store(out12, df, out + i + out_offsets[12]);
hn::Store(out13, df, out + i + out_offsets[13]);
hn::Store(out14, df, out + i + out_offsets[14]);
hn::Store(out15, df, out + i + out_offsets[15]);
}
if HWY_LANES_CONSTEXPR (NF == 8) {
VF out0, out1, out2, out3, out4, out5, out6, out7;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out4 = hn::Load(df, out + i + out_offsets[4]);
out5 = hn::Load(df, out + i + out_offsets[5]);
out6 = hn::Load(df, out + i + out_offsets[6]);
out7 = hn::Load(df, out + i + out_offsets[7]);
Mul8(df, scale, out0, out1, out2, out3, out4, out5, out6, out7);
VF x0 = hn::Load(df, v.Row(pos[0]) + i);
MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7);
VF x1 = hn::Load(df, v.Row(pos[1]) + i);
MulAdd8(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7);
VF x2 = hn::Load(df, v.Row(pos[2]) + i);
MulAdd8(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7);
VF x3 = hn::Load(df, v.Row(pos[3]) + i);
MulAdd8(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7);
VF x4 = hn::Load(df, v.Row(pos[4]) + i);
MulAdd8(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7);
VF x5 = hn::Load(df, v.Row(pos[5]) + i);
MulAdd8(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7);
VF x6 = hn::Load(df, v.Row(pos[6]) + i);
MulAdd8(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7);
VF x7 = hn::Load(df, v.Row(pos[7]) + i);
MulAdd8(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7);
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
hn::Store(out4, df, out + i + out_offsets[4]);
hn::Store(out5, df, out + i + out_offsets[5]);
hn::Store(out6, df, out + i + out_offsets[6]);
hn::Store(out7, df, out + i + out_offsets[7]);
}
if HWY_LANES_CONSTEXPR (NF == 4) {
VF out0, out1, out2, out3;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale));
out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale));
out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale));
out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale));
VF x0 = hn::Load(df, v.Row(pos[0]) + i);
MulAdd4(df, x0, c0, out0, out1, out2, out3);
VF x1 = hn::Load(df, v.Row(pos[1]) + i);
MulAdd4(df, x1, c1, out0, out1, out2, out3);
VF x2 = hn::Load(df, v.Row(pos[2]) + i);
MulAdd4(df, x2, c2, out0, out1, out2, out3);
VF x3 = hn::Load(df, v.Row(pos[3]) + i);
MulAdd4(df, x3, c3, out0, out1, out2, out3);
VF x4 = hn::Load(df, v.Row(pos[4]) + i);
MulAdd4(df, x4, c4, out0, out1, out2, out3);
VF x5 = hn::Load(df, v.Row(pos[5]) + i);
MulAdd4(df, x5, c5, out0, out1, out2, out3);
VF x6 = hn::Load(df, v.Row(pos[6]) + i);
MulAdd4(df, x6, c6, out0, out1, out2, out3);
VF x7 = hn::Load(df, v.Row(pos[7]) + i);
MulAdd4(df, x7, c7, out0, out1, out2, out3);
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
}
i += NF;
}
HWY_DASSERT(size == i);
}
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF c0,
const VF c1, const VF c2, const VF c3,
@ -887,240 +626,134 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF c0,
}
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4Lanes(DF df, const MatPtrT<float>& v,
const size_t* HWY_RESTRICT pos,
const size_t offset, const VF c0,
const VF c1, const VF c2,
const VF c3, VF& sum0, VF& sum1,
VF& sum2, VF& sum3) {
// TODO(rays): Check whether a transpose of c0-c3 is applicable and faster.
VF x0 = hn::Load(df, v.Row(pos[0]) + offset);
MulAdd4(df, x0, hn::BroadcastLane<0>(c0), hn::BroadcastLane<0>(c1),
hn::BroadcastLane<0>(c2), hn::BroadcastLane<0>(c3), sum0, sum1, sum2,
sum3);
VF x1 = hn::Load(df, v.Row(pos[1]) + offset);
MulAdd4(df, x1, hn::BroadcastLane<1>(c0), hn::BroadcastLane<1>(c1),
hn::BroadcastLane<1>(c2), hn::BroadcastLane<1>(c3), sum0, sum1, sum2,
sum3);
VF x2 = hn::Load(df, v.Row(pos[2]) + offset);
MulAdd4(df, x2, hn::BroadcastLane<2>(c0), hn::BroadcastLane<2>(c1),
hn::BroadcastLane<2>(c2), hn::BroadcastLane<2>(c3), sum0, sum1, sum2,
sum3);
VF x3 = hn::Load(df, v.Row(pos[3]) + offset);
MulAdd4(df, x3, hn::BroadcastLane<3>(c0), hn::BroadcastLane<3>(c1),
hn::BroadcastLane<3>(c2), hn::BroadcastLane<3>(c3), sum0, sum1, sum2,
sum3);
HWY_INLINE HWY_MAYBE_UNUSED void MulAddNLanesVT4(
DF df, const BF16* HWY_RESTRICT v, const float* HWY_RESTRICT c,
const size_t num_lanes, VF& sum0a, VF& sum1a, VF& sum2a, VF& sum3a,
VF& sum0b, VF& sum1b, VF& sum2b, VF& sum3b) {
using DBF = hn::ScalableTag<BF16>;
const DBF dbf;
using VBF = hn::Vec<DBF>;
const size_t kNF = hn::Lanes(df);
for (size_t lane = 0; lane < num_lanes; ++lane, v += 2 * kNF) {
VBF v0 = hn::Load(dbf, v);
VF c0 = hn::Set(df, *c++);
VF c1 = hn::Set(df, *c++);
VF c2 = hn::Set(df, *c++);
VF c3 = hn::Set(df, *c++);
VF v0a = hn::PromoteLowerTo(df, v0);
VF v0b = hn::PromoteUpperTo(df, v0);
MulAdd4(df, v0a, c0, c1, c2, c3, sum0a, sum1a, sum2a, sum3a);
MulAdd4(df, v0b, c0, c1, c2, c3, sum0b, sum1b, sum2b, sum3b);
}
}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes(
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
VF x4 = hn::Load(df, v.Row(pos[4]) + offset);
MulAdd4(df, x4, hn::BroadcastLane<4>(c0), hn::BroadcastLane<4>(c1),
hn::BroadcastLane<4>(c2), hn::BroadcastLane<4>(c3), sum0, sum1, sum2,
sum3);
VF x5 = hn::Load(df, v.Row(pos[5]) + offset);
MulAdd4(df, x5, hn::BroadcastLane<5>(c0), hn::BroadcastLane<5>(c1),
hn::BroadcastLane<5>(c2), hn::BroadcastLane<5>(c3), sum0, sum1, sum2,
sum3);
VF x6 = hn::Load(df, v.Row(pos[6]) + offset);
MulAdd4(df, x6, hn::BroadcastLane<6>(c0), hn::BroadcastLane<6>(c1),
hn::BroadcastLane<6>(c2), hn::BroadcastLane<6>(c3), sum0, sum1, sum2,
sum3);
VF x7 = hn::Load(df, v.Row(pos[7]) + offset);
MulAdd4(df, x7, hn::BroadcastLane<7>(c0), hn::BroadcastLane<7>(c1),
hn::BroadcastLane<7>(c2), hn::BroadcastLane<7>(c3), sum0, sum1, sum2,
sum3);
}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes(
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes(
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
VF x8 = hn::Load(df, v.Row(pos[8]) + offset);
MulAdd4(df, x8, hn::BroadcastLane<8>(c0), hn::BroadcastLane<8>(c1),
hn::BroadcastLane<8>(c2), hn::BroadcastLane<8>(c3), sum0, sum1, sum2,
sum3);
VF x9 = hn::Load(df, v.Row(pos[9]) + offset);
MulAdd4(df, x9, hn::BroadcastLane<9>(c0), hn::BroadcastLane<9>(c1),
hn::BroadcastLane<9>(c2), hn::BroadcastLane<9>(c3), sum0, sum1, sum2,
sum3);
VF x10 = hn::Load(df, v.Row(pos[10]) + offset);
MulAdd4(df, x10, hn::BroadcastLane<10>(c0), hn::BroadcastLane<10>(c1),
hn::BroadcastLane<10>(c2), hn::BroadcastLane<10>(c3), sum0, sum1,
sum2, sum3);
VF x11 = hn::Load(df, v.Row(pos[11]) + offset);
MulAdd4(df, x11, hn::BroadcastLane<11>(c0), hn::BroadcastLane<11>(c1),
hn::BroadcastLane<11>(c2), hn::BroadcastLane<11>(c3), sum0, sum1,
sum2, sum3);
VF x12 = hn::Load(df, v.Row(pos[12]) + offset);
MulAdd4(df, x12, hn::BroadcastLane<12>(c0), hn::BroadcastLane<12>(c1),
hn::BroadcastLane<12>(c2), hn::BroadcastLane<12>(c3), sum0, sum1,
sum2, sum3);
VF x13 = hn::Load(df, v.Row(pos[13]) + offset);
MulAdd4(df, x13, hn::BroadcastLane<13>(c0), hn::BroadcastLane<13>(c1),
hn::BroadcastLane<13>(c2), hn::BroadcastLane<13>(c3), sum0, sum1,
sum2, sum3);
VF x14 = hn::Load(df, v.Row(pos[14]) + offset);
MulAdd4(df, x14, hn::BroadcastLane<14>(c0), hn::BroadcastLane<14>(c1),
hn::BroadcastLane<14>(c2), hn::BroadcastLane<14>(c3), sum0, sum1,
sum2, sum3);
VF x15 = hn::Load(df, v.Row(pos[15]) + offset);
MulAdd4(df, x15, hn::BroadcastLane<15>(c0), hn::BroadcastLane<15>(c1),
hn::BroadcastLane<15>(c2), hn::BroadcastLane<15>(c3), sum0, sum1,
sum2, sum3);
}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes(
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {}
// For an NFx4 tile of float values in 4xNF-lane registers, multiplies NF rows
// of V by the corresponding values in c0-c3 and adds them to NF rows of out,
// For a 2NFx4 tile of float values in 8xNF-lane registers, multiplies 2NF rows
// of V by the corresponding values in c00-c31 and adds them to 2NF rows of out,
// after first prescaling out by scale.
// The depth (size) must be a multiple of NF.
// The depth (size) must be a multiple of 2NF.
template <class DF, class VF = hn::Vec<DF>>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4(
DF df, const float* HWY_RESTRICT scales, const VF c0, const VF c1,
const VF c2, const VF c3, const MatPtrT<float>& v,
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT4Mem(
DF df, const float* HWY_RESTRICT scales, const VF c00, const VF c01,
const VF c10, const VF c11, const VF c20, const VF c21, const VF c30,
const VF c31, const MatPtrT<BF16>& v, const size_t* HWY_RESTRICT pos,
size_t num_lanes, float* HWY_RESTRICT out,
const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
constexpr size_t kMaxNF = hn::MaxLanes(df);
const BF16* HWY_RESTRICT v_bf = v.Row(pos[0] / (2 * NF));
HWY_DASSERT(pos[0] % (2 * NF) == 0);
HWY_ALIGN float c_mem[8 * kMaxNF];
hn::StoreInterleaved4(c00, c10, c20, c30, df, c_mem);
hn::StoreInterleaved4(c01, c11, c21, c31, df, c_mem + 4 * NF);
size_t i = 0;
while (i + NF <= size) {
VF out0, out1, out2, out3;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out0 = hn::Mul(out0, hn::Set(df, scales[0]));
out1 = hn::Mul(out1, hn::Set(df, scales[1]));
out2 = hn::Mul(out2, hn::Set(df, scales[2]));
out3 = hn::Mul(out3, hn::Set(df, scales[3]));
MulAdd4Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, out3);
if HWY_LANES_CONSTEXPR (NF >= 8) {
MulAddSecond4Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, out3);
if HWY_LANES_CONSTEXPR (NF >= 16) {
MulAddSecond8Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2,
out3);
}
}
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
i += NF;
while (i + NF * 2 <= size) {
VF out0a, out1a, out2a, out3a, out0b, out1b, out2b, out3b;
out0a = hn::Load(df, out + i + out_offsets[0]);
out1a = hn::Load(df, out + i + out_offsets[1]);
out2a = hn::Load(df, out + i + out_offsets[2]);
out3a = hn::Load(df, out + i + out_offsets[3]);
VF scale0 = hn::Set(df, scales[0]);
VF scale1 = hn::Set(df, scales[1]);
VF scale2 = hn::Set(df, scales[2]);
VF scale3 = hn::Set(df, scales[3]);
out0a = hn::Mul(out0a, scale0);
out1a = hn::Mul(out1a, scale1);
out2a = hn::Mul(out2a, scale2);
out3a = hn::Mul(out3a, scale3);
out0b = hn::Load(df, out + i + NF + out_offsets[0]);
out1b = hn::Load(df, out + i + NF + out_offsets[1]);
out2b = hn::Load(df, out + i + NF + out_offsets[2]);
out3b = hn::Load(df, out + i + NF + out_offsets[3]);
out0b = hn::Mul(out0b, scale0);
out1b = hn::Mul(out1b, scale1);
out2b = hn::Mul(out2b, scale2);
out3b = hn::Mul(out3b, scale3);
MulAddNLanesVT4(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out1a,
out2a, out3a, out0b, out1b, out2b, out3b);
hn::Store(out0a, df, out + i + out_offsets[0]);
hn::Store(out1a, df, out + i + out_offsets[1]);
hn::Store(out2a, df, out + i + out_offsets[2]);
hn::Store(out3a, df, out + i + out_offsets[3]);
hn::Store(out0b, df, out + i + NF + out_offsets[0]);
hn::Store(out1b, df, out + i + NF + out_offsets[1]);
hn::Store(out2b, df, out + i + NF + out_offsets[2]);
hn::Store(out3b, df, out + i + NF + out_offsets[3]);
i += NF * 2;
v_bf += 4 * NF * NF;
}
HWY_DASSERT(size == i);
}
// Prescales NF rows of out by scale, then multiplies 1 row of V by the
// corresponding values in c0 and adds them to the NF rows of out.
// The depth (size) must be a multiple of NF.
template <class DF, class VF = hn::Vec<DF>>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
DF df, const VF scale, const VF c0, const MatPtrT<float>& v,
const size_t pos, float* HWY_RESTRICT out,
const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
HWY_INLINE HWY_MAYBE_UNUSED void MulAddNLanesVT1(DF df,
const BF16* HWY_RESTRICT v,
const float* HWY_RESTRICT c,
const size_t num_lanes,
VF& sum0a, VF& sum0b) {
using DBF = hn::ScalableTag<BF16>;
const DBF dbf;
using VBF = hn::Vec<DBF>;
const size_t kNF = hn::Lanes(df);
for (size_t lane = 0; lane < num_lanes; ++lane, v += 2 * kNF) {
VBF v0 = hn::Load(dbf, v);
VF c0 = hn::Set(df, *c++);
VF v0a = hn::PromoteLowerTo(df, v0);
VF v0b = hn::PromoteUpperTo(df, v0);
sum0a = hn::MulAdd(v0a, c0, sum0a);
sum0b = hn::MulAdd(v0b, c0, sum0b);
}
}
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT1Mem(
DF df, const float* HWY_RESTRICT scales, const VF c00, const VF c01,
const MatPtrT<BF16>& v, const size_t* HWY_RESTRICT pos, size_t num_lanes,
float* HWY_RESTRICT out, const uint32_t* HWY_RESTRICT out_offsets,
const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
constexpr size_t kMaxNF = hn::MaxLanes(df);
const BF16* HWY_RESTRICT v_bf = v.Row(pos[0] / (2 * NF));
HWY_DASSERT(pos[0] % (2 * NF) == 0);
HWY_ALIGN float c_mem[2 * kMaxNF];
hn::Store(c00, df, c_mem);
hn::Store(c01, df, c_mem + NF);
size_t i = 0;
while (i + NF <= size) {
if HWY_LANES_CONSTEXPR (NF == 16) {
VF out0, out1, out2, out3, out4, out5, out6, out7;
VF out8, out9, out10, out11, out12, out13, out14, out15;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out4 = hn::Load(df, out + i + out_offsets[4]);
out5 = hn::Load(df, out + i + out_offsets[5]);
out6 = hn::Load(df, out + i + out_offsets[6]);
out7 = hn::Load(df, out + i + out_offsets[7]);
out8 = hn::Load(df, out + i + out_offsets[8]);
out9 = hn::Load(df, out + i + out_offsets[9]);
out10 = hn::Load(df, out + i + out_offsets[10]);
out11 = hn::Load(df, out + i + out_offsets[11]);
out12 = hn::Load(df, out + i + out_offsets[12]);
out13 = hn::Load(df, out + i + out_offsets[13]);
out14 = hn::Load(df, out + i + out_offsets[14]);
out15 = hn::Load(df, out + i + out_offsets[15]);
Mul16(df, scale, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x0 = hn::Load(df, v.Row(pos) + i);
MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
hn::Store(out4, df, out + i + out_offsets[4]);
hn::Store(out5, df, out + i + out_offsets[5]);
hn::Store(out6, df, out + i + out_offsets[6]);
hn::Store(out7, df, out + i + out_offsets[7]);
hn::Store(out8, df, out + i + out_offsets[8]);
hn::Store(out9, df, out + i + out_offsets[9]);
hn::Store(out10, df, out + i + out_offsets[10]);
hn::Store(out11, df, out + i + out_offsets[11]);
hn::Store(out12, df, out + i + out_offsets[12]);
hn::Store(out13, df, out + i + out_offsets[13]);
hn::Store(out14, df, out + i + out_offsets[14]);
hn::Store(out15, df, out + i + out_offsets[15]);
}
if HWY_LANES_CONSTEXPR (NF == 8) {
VF out0, out1, out2, out3, out4, out5, out6, out7;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out4 = hn::Load(df, out + i + out_offsets[4]);
out5 = hn::Load(df, out + i + out_offsets[5]);
out6 = hn::Load(df, out + i + out_offsets[6]);
out7 = hn::Load(df, out + i + out_offsets[7]);
Mul8(df, scale, out0, out1, out2, out3, out4, out5, out6, out7);
VF x0 = hn::Load(df, v.Row(pos) + i);
MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7);
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
hn::Store(out4, df, out + i + out_offsets[4]);
hn::Store(out5, df, out + i + out_offsets[5]);
hn::Store(out6, df, out + i + out_offsets[6]);
hn::Store(out7, df, out + i + out_offsets[7]);
}
if HWY_LANES_CONSTEXPR (NF == 4) {
VF out0, out1, out2, out3;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale));
out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale));
out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale));
out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale));
VF x0 = hn::Load(df, v.Row(pos) + i);
MulAdd4(df, x0, c0, out0, out1, out2, out3);
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
}
i += NF;
while (i + NF * 2 <= size) {
VF out0a, out0b;
out0a = hn::Load(df, out + i + out_offsets[0]);
VF scale0 = hn::Set(df, scales[0]);
out0a = hn::Mul(out0a, scale0);
out0b = hn::Load(df, out + i + NF + out_offsets[0]);
out0b = hn::Mul(out0b, scale0);
MulAddNLanesVT1(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out0b);
hn::Store(out0a, df, out + i + out_offsets[0]);
hn::Store(out0b, df, out + i + NF + out_offsets[0]);
i += NF * 2;
v_bf += 4 * NF * NF;
}
HWY_DASSERT(size == i);
}

View File

@ -202,6 +202,17 @@ class MatPtr : public IFields {
override_rows_ = static_cast<uint32_t>(rows);
}
// Changes the number of rows and columns without reallocating the memory.
// Increases cols by factor and reduces rows by factor.
// The rows must be divisible by factor and the matrix must be packed.
void ReshapePackedRowsToCols(size_t factor) {
HWY_ASSERT(IsPacked());
HWY_ASSERT(private_rows_ % factor == 0);
private_rows_ /= factor;
cols_ *= factor;
stride_ *= factor;
}
// Offset by which to advance pointers to the next row.
size_t Stride() const { return stride_; }

View File

@ -106,7 +106,8 @@ template <typename T>
void FillMatPtrT(MatPtrT<T>& mat) {
for (int i = 0; i < mat.Rows(); ++i) {
for (int j = 0; j < mat.Cols(); ++j) {
mat.Row(i)[j] = hwy::Unpredictable1() * 0.01f * (i + j + 1);
mat.Row(i)[j] =
hwy::ConvertScalarTo<T>(hwy::Unpredictable1() * 0.01f * (i + j + 1));
}
}
}

View File

@ -17,14 +17,14 @@ const char* ZoneName(Zones zone) {
return "FlashAttention.Inclusive";
case Zones::kFlashAttentionRmsNormAndPositionalEncoding:
return "FlashAttention.RMSNormAndPositionalEncoding";
case Zones::kFlashAttentionSingleFlashAttention:
return "FlashAttention.SingleFlashAttention";
case Zones::kFlashAttentionTileFlashAttention:
return "FlashAttention.TileFlashAttention";
case Zones::kFlashAttentionTileFlashAttention1:
return "FlashAttention.TileFlashAttention1";
case Zones::kFlashAttentionTileFlashAttention4:
return "FlashAttention.TileFlashAttention4";
case Zones::kFlashAttentionTransposeQ:
return "FlashAttention.TransposeQ";
case Zones::kFlashAttentionTileFlashAttention8:
return "FlashAttention.TileFlashAttention8";
case Zones::kFlashAttentionCombineSplit:
return "FlashAttention.CombineSplit";
case Zones::kGenActivation:
return "Gen.Activation";
case Zones::kGenActivationFused:

View File

@ -14,10 +14,10 @@ enum class Zones { // Keep sorted
kFlashAttentionFlashAttention,
kFlashAttentionInclusive,
kFlashAttentionRmsNormAndPositionalEncoding,
kFlashAttentionSingleFlashAttention,
kFlashAttentionTileFlashAttention,
kFlashAttentionTileFlashAttention1,
kFlashAttentionTileFlashAttention4,
kFlashAttentionTransposeQ,
kFlashAttentionTileFlashAttention8,
kFlashAttentionCombineSplit,
kGenActivation,
kGenActivationFused,
kGenAttention,