mirror of https://github.com/google/gemma.cpp.git
Added flash attention, with both a single-q function, and a register-tiled function.
The register-tiled version achieves a speed-up by a factor of about 9.7 over the previous attention function on an AVX3-enabled machine. PiperOrigin-RevId: 804913784
This commit is contained in:
parent
24b1760f03
commit
f10ac41a20
23
BUILD.bazel
23
BUILD.bazel
|
|
@ -117,6 +117,27 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "flash_attention_test",
|
||||||
|
srcs = ["gemma/flash_attention_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":configs",
|
||||||
|
":gemma_args",
|
||||||
|
":gemma_lib",
|
||||||
|
":kv_cache",
|
||||||
|
":mat",
|
||||||
|
":matmul",
|
||||||
|
":ops",
|
||||||
|
":threading_context",
|
||||||
|
":weights",
|
||||||
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
|
"//compression:compress",
|
||||||
|
"//compression:types",
|
||||||
|
"@highway//:hwy",
|
||||||
|
"@highway//:hwy_test_util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "threading_test",
|
name = "threading_test",
|
||||||
srcs = ["util/threading_test.cc"],
|
srcs = ["util/threading_test.cc"],
|
||||||
|
|
@ -526,12 +547,14 @@ cc_library(
|
||||||
name = "gemma_lib",
|
name = "gemma_lib",
|
||||||
srcs = [
|
srcs = [
|
||||||
"gemma/attention.cc",
|
"gemma/attention.cc",
|
||||||
|
"gemma/flash_attention.cc",
|
||||||
"gemma/gemma.cc",
|
"gemma/gemma.cc",
|
||||||
"gemma/vit.cc",
|
"gemma/vit.cc",
|
||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"gemma/activations.h",
|
"gemma/activations.h",
|
||||||
"gemma/attention.h",
|
"gemma/attention.h",
|
||||||
|
"gemma/flash_attention.h",
|
||||||
"gemma/gemma.h",
|
"gemma/gemma.h",
|
||||||
"gemma/vit.h",
|
"gemma/vit.h",
|
||||||
],
|
],
|
||||||
|
|
|
||||||
|
|
@ -79,6 +79,8 @@ set(SOURCES
|
||||||
gemma/attention.h
|
gemma/attention.h
|
||||||
gemma/configs.cc
|
gemma/configs.cc
|
||||||
gemma/configs.h
|
gemma/configs.h
|
||||||
|
gemma/flash_attention.cc
|
||||||
|
gemma/flash_attention.h
|
||||||
gemma/gemma_args.h
|
gemma/gemma_args.h
|
||||||
gemma/gemma-inl.h
|
gemma/gemma-inl.h
|
||||||
gemma/gemma.cc
|
gemma/gemma.cc
|
||||||
|
|
@ -216,6 +218,7 @@ set(GEMMA_TEST_FILES
|
||||||
compression/nuq_test.cc
|
compression/nuq_test.cc
|
||||||
compression/sfp_test.cc
|
compression/sfp_test.cc
|
||||||
evals/gemma_test.cc
|
evals/gemma_test.cc
|
||||||
|
gemma/flash_attention_test.cc
|
||||||
gemma/tensor_info_test.cc
|
gemma/tensor_info_test.cc
|
||||||
io/blob_store_test.cc
|
io/blob_store_test.cc
|
||||||
io/fields_test.cc
|
io/fields_test.cc
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,11 @@ struct AttentionActivations {
|
||||||
? layer_config.heads * 3 * layer_config.qkv_dim
|
? layer_config.heads * 3 * layer_config.qkv_dim
|
||||||
: layer_config.heads * layer_config.qkv_dim,
|
: layer_config.heads * layer_config.qkv_dim,
|
||||||
allocator)),
|
allocator)),
|
||||||
|
q_T(MatFactory("q_T", layer_config.qkv_dim,
|
||||||
|
config.vocab_size == 0
|
||||||
|
? batch_size * layer_config.heads * 3
|
||||||
|
: batch_size * layer_config.heads,
|
||||||
|
allocator)),
|
||||||
pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
|
pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
|
||||||
config.model_dim, allocator)),
|
config.model_dim, allocator)),
|
||||||
att(MatFactory("att", batch_size, layer_config.heads * seq_len,
|
att(MatFactory("att", batch_size, layer_config.heads * seq_len,
|
||||||
|
|
@ -90,11 +94,13 @@ struct AttentionActivations {
|
||||||
// If we forget any MatMul outputs here, debug builds print a warning but
|
// If we forget any MatMul outputs here, debug builds print a warning but
|
||||||
// fill them in each MatMul call.
|
// fill them in each MatMul call.
|
||||||
q.AllocateAndAttachRowPtrs(row_ptrs);
|
q.AllocateAndAttachRowPtrs(row_ptrs);
|
||||||
|
q_T.AllocateAndAttachRowPtrs(row_ptrs);
|
||||||
att_sums.AllocateAndAttachRowPtrs(row_ptrs);
|
att_sums.AllocateAndAttachRowPtrs(row_ptrs);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetBatchSize(size_t batch_size) {
|
void SetBatchSize(size_t batch_size) {
|
||||||
q.OverrideRows(batch_size);
|
q.OverrideRows(batch_size);
|
||||||
|
q_T.OverrideRows(batch_size);
|
||||||
|
|
||||||
pre_att_rms_out.OverrideRows(batch_size);
|
pre_att_rms_out.OverrideRows(batch_size);
|
||||||
att.OverrideRows(batch_size);
|
att.OverrideRows(batch_size);
|
||||||
|
|
@ -105,6 +111,7 @@ struct AttentionActivations {
|
||||||
const ModelConfig& config;
|
const ModelConfig& config;
|
||||||
|
|
||||||
MatStorageT<float> q; // query
|
MatStorageT<float> q; // query
|
||||||
|
MatStorageT<float> q_T; // Transposed to maximize attention speed.
|
||||||
|
|
||||||
MatStorageT<float> pre_att_rms_out;
|
MatStorageT<float> pre_att_rms_out;
|
||||||
MatStorageT<float> att; // attention vector
|
MatStorageT<float> att; // attention vector
|
||||||
|
|
|
||||||
|
|
@ -41,12 +41,16 @@
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
// After highway.h
|
// After highway.h
|
||||||
#include "compression/compress-inl.h"
|
#include "compression/compress-inl.h"
|
||||||
|
#include "gemma/flash_attention.h"
|
||||||
#include "ops/ops-inl.h"
|
#include "ops/ops-inl.h"
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
HWY_BEFORE_NAMESPACE();
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
namespace HWY_NAMESPACE {
|
namespace HWY_NAMESPACE {
|
||||||
|
|
||||||
|
constexpr int kFlagReserved = 1; // LINTER: unused, reserved for future use.
|
||||||
|
constexpr int kUseOldAttention = 2;
|
||||||
|
|
||||||
// Computes Q.K scores, which are "logits" (or scores) stored to att.
|
// Computes Q.K scores, which are "logits" (or scores) stored to att.
|
||||||
// `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
|
// `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
|
||||||
static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
|
static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
|
||||||
|
|
@ -71,11 +75,11 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void PositionalEncodingQK(float* qk, const size_t layer_idx,
|
void PositionalEncodingQK(float* qk, const size_t layer_idx,
|
||||||
const LayerWeightsPtrs& layer,
|
const LayerWeightsPtrs& layer,
|
||||||
const AttentionActivations& activations,
|
const AttentionActivations& activations,
|
||||||
hwy::Profiler& p, const size_t worker,
|
hwy::Profiler& p, const size_t worker,
|
||||||
const size_t pos, const float mul = 1.0f) {
|
const size_t pos, const float mul) {
|
||||||
const size_t qkv_dim = layer.layer_config.qkv_dim;
|
const size_t qkv_dim = layer.layer_config.qkv_dim;
|
||||||
const PostQKType& post_qk = layer.layer_config.post_qk;
|
const PostQKType& post_qk = layer.layer_config.post_qk;
|
||||||
// qk is either q or k, so qkv_dim is the length we operate on.
|
// qk is either q or k, so qkv_dim is the length we operate on.
|
||||||
|
|
@ -165,8 +169,7 @@ void SingleDotSoftmaxWeightedSum(
|
||||||
|
|
||||||
// The attention window usually starts at 0 unless `pos` is larger than
|
// The attention window usually starts at 0 unless `pos` is larger than
|
||||||
// the attention window size, then it is `pos` - window_size + 1.
|
// the attention window size, then it is `pos` - window_size + 1.
|
||||||
static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,
|
size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) {
|
||||||
size_t layer_idx) {
|
|
||||||
const size_t att_window_size = config.attention_window_sizes[layer_idx];
|
const size_t att_window_size = config.attention_window_sizes[layer_idx];
|
||||||
return pos - HWY_MIN(att_window_size - 1, pos);
|
return pos - HWY_MIN(att_window_size - 1, pos);
|
||||||
}
|
}
|
||||||
|
|
@ -314,7 +317,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
||||||
}
|
}
|
||||||
|
|
||||||
PositionalEncodingQK(kv_f32, layer_idx, layer, activations,
|
PositionalEncodingQK(kv_f32, layer_idx, layer, activations,
|
||||||
env.ctx.profiler, worker, pos);
|
env.ctx.profiler, worker, pos, /*mul=*/1.0f);
|
||||||
CompressPerThread tls;
|
CompressPerThread tls;
|
||||||
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
|
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
|
||||||
});
|
});
|
||||||
|
|
@ -354,8 +357,12 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
|
||||||
(void)layer_config; // only used in HWY_DASSERT
|
(void)layer_config; // only used in HWY_DASSERT
|
||||||
|
|
||||||
ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
|
ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
|
||||||
|
if (flags & kUseOldAttention) {
|
||||||
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
|
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
|
||||||
env.ctx);
|
env.ctx);
|
||||||
|
} else {
|
||||||
|
FlashAttention(num_tokens, layer_idx, layer, activations, qbatch, env.ctx);
|
||||||
|
}
|
||||||
SumHeads(layer, activations, env);
|
SumHeads(layer, activations, env);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,14 @@ namespace gcpp {
|
||||||
// Passed to HWY_VISIT_TARGETS; declares for one target.
|
// Passed to HWY_VISIT_TARGETS; declares for one target.
|
||||||
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
|
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
|
||||||
namespace NAMESPACE { \
|
namespace NAMESPACE { \
|
||||||
|
void PositionalEncodingQK(float* qk, size_t layer_idx, \
|
||||||
|
const LayerWeightsPtrs& layer, \
|
||||||
|
const AttentionActivations& activations, \
|
||||||
|
hwy::Profiler& p, size_t worker, size_t pos, \
|
||||||
|
float mul); \
|
||||||
|
\
|
||||||
|
size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx); \
|
||||||
|
\
|
||||||
void SingleDotSoftmaxWeightedSum( \
|
void SingleDotSoftmaxWeightedSum( \
|
||||||
const size_t pos, const size_t start_pos, const size_t last_pos, \
|
const size_t pos, const size_t start_pos, const size_t last_pos, \
|
||||||
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
|
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,510 @@
|
||||||
|
// Copyright 2025 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||||
|
#include "util/threading_context.h"
|
||||||
|
#ifndef HWY_DISABLED_TARGETS
|
||||||
|
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||||
|
#endif // HWY_DISABLED_TARGETS
|
||||||
|
|
||||||
|
#include "gemma/activations.h"
|
||||||
|
#include "gemma/configs.h" // kMaxQKVDim
|
||||||
|
#include "gemma/gemma.h"
|
||||||
|
#include "gemma/weights.h"
|
||||||
|
#include "util/threading.h"
|
||||||
|
#include "hwy/profiler.h"
|
||||||
|
|
||||||
|
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||||
|
// which we pass the filename via macro 'argument'.
|
||||||
|
// clang-format off
|
||||||
|
#undef HWY_TARGET_INCLUDE
|
||||||
|
#define HWY_TARGET_INCLUDE "gemma/flash_attention.cc" // NOLINT
|
||||||
|
// clang-format on
|
||||||
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
|
#include "hwy/highway.h"
|
||||||
|
// After highway.h
|
||||||
|
#include "compression/compress-inl.h"
|
||||||
|
#include "gemma/attention.h"
|
||||||
|
#include "ops/ops-inl.h"
|
||||||
|
|
||||||
|
HWY_BEFORE_NAMESPACE();
|
||||||
|
namespace gcpp {
|
||||||
|
namespace HWY_NAMESPACE {
|
||||||
|
|
||||||
|
// Transposes q into q_t.
|
||||||
|
// Both are 4D tensors stuffed into a 2-D MatPtrT.
|
||||||
|
// q has shape [batch, qbatch][head, qkv_dim].
|
||||||
|
// q_t has shape [qkv_dim][qbatch, head, batch] in order to make the maximum
|
||||||
|
// possible consecutive elements have the same KV.
|
||||||
|
static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
|
||||||
|
const size_t qbatch_size, ThreadingContext& ctx) {
|
||||||
|
static const auto zone = ctx.profiler.AddZone("Gen.Attention.TransposeQ");
|
||||||
|
const size_t num_heads = q.Cols() / q_t.Rows();
|
||||||
|
const size_t batch_size = q.Rows() / qbatch_size;
|
||||||
|
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
|
||||||
|
PROFILER_ZONE3(ctx.profiler, worker, zone);
|
||||||
|
float* HWY_RESTRICT qt_row = q_t.Row(task);
|
||||||
|
for (size_t qi = 0; qi < qbatch_size; ++qi)
|
||||||
|
for (size_t h = 0; h < num_heads; ++h) {
|
||||||
|
for (size_t b = 0; b < batch_size; ++b) {
|
||||||
|
qt_row[(qi * num_heads + h) * batch_size + b] =
|
||||||
|
q.Row(b * qbatch_size + qi)[h * q_t.Rows() + task];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
{
|
||||||
|
// Full parallelism is helpful, SmallParallelFor is insufficient.
|
||||||
|
HierarchicalParallelFor(q_t.Rows(), ctx.pools, func);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Updates q in place for RMSNorm and positional encoding.
|
||||||
|
void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
|
||||||
|
MatPtrT<KV_t>& q, const size_t layer_idx,
|
||||||
|
const LayerWeightsPtrs& layer,
|
||||||
|
const AttentionActivations& activations,
|
||||||
|
ThreadingContext& ctx) {
|
||||||
|
static const auto zone =
|
||||||
|
ctx.profiler.AddZone("Gen.Attention.RMSNormAndPositionalEncoding");
|
||||||
|
const float query_scale = activations.query_scale;
|
||||||
|
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
|
||||||
|
PROFILER_ZONE3(ctx.profiler, worker, zone);
|
||||||
|
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||||
|
for (size_t h = 0; h < layer.layer_config.heads; ++h) {
|
||||||
|
const size_t tq_idx = qbatch.Size() * task + qi;
|
||||||
|
// Find the token position in the query and calculate
|
||||||
|
// the range of cache positions to attend to.
|
||||||
|
const size_t pos = qbatch.Pos(qi) + task;
|
||||||
|
float* HWY_RESTRICT q_row =
|
||||||
|
q.Row(tq_idx) + h * layer.layer_config.qkv_dim;
|
||||||
|
// Apply rope and scaling to Q.
|
||||||
|
if (layer.query_norm_scale.HasPtr()) {
|
||||||
|
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
|
||||||
|
RMSNormInplace(weights_t->PackedScale1(), q_row,
|
||||||
|
layer.layer_config.qkv_dim, ctx.profiler, worker);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler,
|
||||||
|
worker, pos, query_scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
{
|
||||||
|
// Full parallelism is helpful, SmallParallelFor is insufficient.
|
||||||
|
HierarchicalParallelFor(num_tokens, ctx.pools, func);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculates the complete attention outputs for a single row of q.
|
||||||
|
void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
|
||||||
|
const float* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
|
||||||
|
const MatPtrT<KV_t>& v, const size_t layer_idx,
|
||||||
|
const LayerWeightsPtrs& layer,
|
||||||
|
const AttentionActivations& activations,
|
||||||
|
float* HWY_RESTRICT att_out, hwy::Profiler& p,
|
||||||
|
const size_t worker) {
|
||||||
|
static const auto zone = p.AddZone("Gen.Attention.SingleFlashAttention");
|
||||||
|
PROFILER_ZONE3(p, worker, zone);
|
||||||
|
const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
|
||||||
|
float m = Dot(q, k.Row(pos_mod), k.Cols());
|
||||||
|
float d = 1.0f;
|
||||||
|
// This is just a copy of the first token.
|
||||||
|
MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), p, worker);
|
||||||
|
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
|
||||||
|
const size_t pos_mod = activations.div_seq_len.Remainder(pos);
|
||||||
|
float x = Dot(q, k.Row(pos_mod), k.Cols());
|
||||||
|
if (activations.config.att_cap > 0.0f) {
|
||||||
|
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
|
||||||
|
x = activations.config.att_cap *
|
||||||
|
std::tanh(x / activations.config.att_cap);
|
||||||
|
}
|
||||||
|
float m_new = std::max(m, x);
|
||||||
|
float scale = d * std::exp(m - m_new);
|
||||||
|
x = std::exp(x - m_new);
|
||||||
|
m = m_new;
|
||||||
|
d = scale + x;
|
||||||
|
float one_over_d = 1.0f / d;
|
||||||
|
x *= one_over_d;
|
||||||
|
scale *= one_over_d;
|
||||||
|
MulByConst(scale, att_out, v.Cols(), p, worker);
|
||||||
|
MulByConstAndAdd(x, v.Row(pos_mod), att_out, v.Cols(), p, worker);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Computes and returns a single vector of NF Q.K dot products, which represents
|
||||||
|
// the dot products of NF rows of Q for a single K timestep.
|
||||||
|
template <class DF, class VF = hn::Vec<DF>>
|
||||||
|
VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
|
||||||
|
const size_t k_pos, const MatPtrT<KV_t>& q,
|
||||||
|
const MatPtrT<KV_t>& k, hwy::Profiler& p, const size_t worker) {
|
||||||
|
hn::TFromD<DF> results[hn::MaxLanes(df)];
|
||||||
|
for (size_t i = 0; i < hn::Lanes(df); ++i) {
|
||||||
|
results[i] = Dot(q.Row(0) + q_offsets[i], k.Row(k_pos), k.Cols());
|
||||||
|
}
|
||||||
|
return hn::LoadU(df, results);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns an 8xNF tile of Q.K dot products, in single precision.
|
||||||
|
// This is the result of NF rows of Q against 8 K timesteps, with positions
|
||||||
|
// given by k_pos[0..7]. Q has been transposed so that the NF rows are read in
|
||||||
|
// consecutive elements, and other columns by adding q_stride.
|
||||||
|
template <class DF, class VF = hn::Vec<DF>>
|
||||||
|
void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
|
||||||
|
const MatPtrT<KV_t>& k, const size_t* k_pos,
|
||||||
|
hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1,
|
||||||
|
VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6,
|
||||||
|
VF& sum7) {
|
||||||
|
constexpr size_t kHTileSize = 8;
|
||||||
|
sum0 = hn::Zero(df);
|
||||||
|
sum1 = hn::Zero(df);
|
||||||
|
sum2 = hn::Zero(df);
|
||||||
|
sum3 = hn::Zero(df);
|
||||||
|
sum4 = hn::Zero(df);
|
||||||
|
sum5 = hn::Zero(df);
|
||||||
|
sum6 = hn::Zero(df);
|
||||||
|
sum7 = hn::Zero(df);
|
||||||
|
const float* HWY_RESTRICT k_row[kHTileSize];
|
||||||
|
for (int i = 0; i < kHTileSize; ++i) {
|
||||||
|
k_row[i] = k.Row(k_pos[i]);
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < k.Cols(); ++i) {
|
||||||
|
VF q_vec = hn::Load(df, q);
|
||||||
|
VF k_0 = hn::Set(df, k_row[0][i]);
|
||||||
|
sum0 = hn::MulAdd(q_vec, k_0, sum0);
|
||||||
|
VF k_1 = hn::Set(df, k_row[1][i]);
|
||||||
|
sum1 = hn::MulAdd(q_vec, k_1, sum1);
|
||||||
|
VF k_2 = hn::Set(df, k_row[2][i]);
|
||||||
|
sum2 = hn::MulAdd(q_vec, k_2, sum2);
|
||||||
|
VF k_3 = hn::Set(df, k_row[3][i]);
|
||||||
|
sum3 = hn::MulAdd(q_vec, k_3, sum3);
|
||||||
|
VF k_4 = hn::Set(df, k_row[4][i]);
|
||||||
|
sum4 = hn::MulAdd(q_vec, k_4, sum4);
|
||||||
|
VF k_5 = hn::Set(df, k_row[5][i]);
|
||||||
|
sum5 = hn::MulAdd(q_vec, k_5, sum5);
|
||||||
|
VF k_6 = hn::Set(df, k_row[6][i]);
|
||||||
|
sum6 = hn::MulAdd(q_vec, k_6, sum6);
|
||||||
|
VF k_7 = hn::Set(df, k_row[7][i]);
|
||||||
|
sum7 = hn::MulAdd(q_vec, k_7, sum7);
|
||||||
|
q += q_stride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the element-wise maximum of 8 vectors, in a single vector.
|
||||||
|
template <class DF, class VF = hn::Vec<DF>>
|
||||||
|
VF HWY_INLINE ElementwiseMaxOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
|
||||||
|
const VF& x3, const VF& x4, const VF& x5,
|
||||||
|
const VF& x6, const VF& x7) {
|
||||||
|
VF m0 = hn::Max(x0, x1);
|
||||||
|
VF m1 = hn::Max(x2, x3);
|
||||||
|
VF m2 = hn::Max(x4, x5);
|
||||||
|
VF m3 = hn::Max(x6, x7);
|
||||||
|
m0 = hn::Max(m0, m1);
|
||||||
|
m2 = hn::Max(m2, m3);
|
||||||
|
return hn::Max(m0, m2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the element-wise sum of 8 vectors, in a single vector.
|
||||||
|
template <class DF, class VF = hn::Vec<DF>>
|
||||||
|
VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
|
||||||
|
const VF& x3, const VF& x4, const VF& x5,
|
||||||
|
const VF& x6, const VF& x7) {
|
||||||
|
VF sum0 = hn::Add(x0, x1);
|
||||||
|
VF sum1 = hn::Add(x2, x3);
|
||||||
|
VF sum2 = hn::Add(x4, x5);
|
||||||
|
VF sum3 = hn::Add(x6, x7);
|
||||||
|
sum0 = hn::Add(sum0, sum1);
|
||||||
|
sum2 = hn::Add(sum2, sum3);
|
||||||
|
return hn::Add(sum0, sum2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sweeps a tile of 8xNF accumulators from start_pos to min_last_pos, then
|
||||||
|
// sweeps the remaining timesteps in the range (min_last_pos, max_last_pos].
|
||||||
|
void TileFlashAttention(
|
||||||
|
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
|
||||||
|
const StridedView<float>& qT, const MatPtrT<KV_t>& k,
|
||||||
|
const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos,
|
||||||
|
const size_t min_last_pos, const size_t max_last_pos,
|
||||||
|
const MatPtrT<KV_t>& v, const size_t layer_idx,
|
||||||
|
const LayerWeightsPtrs& layer, const AttentionActivations& activations,
|
||||||
|
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
|
||||||
|
hwy::Profiler& p, const size_t worker) {
|
||||||
|
static const auto zone = p.AddZone("Gen.Attention.TileFlashAttention");
|
||||||
|
PROFILER_ZONE3(p, worker, zone);
|
||||||
|
constexpr int kHTileSize = 8;
|
||||||
|
using DF = hn::ScalableTag<float>;
|
||||||
|
const DF df;
|
||||||
|
using VF = hn::Vec<DF>;
|
||||||
|
using DI = hn::ScalableTag<uint32_t>;
|
||||||
|
const DI di;
|
||||||
|
using VI = hn::Vec<DI>;
|
||||||
|
VI lasts = hn::LoadU(di, last_pos);
|
||||||
|
VF old_m = hn::Set(df, -std::numeric_limits<float>::max() / 2.0f);
|
||||||
|
VF old_d = hn::Zero(df);
|
||||||
|
const float* HWY_RESTRICT qT_row = qT.Row(0);
|
||||||
|
const size_t qT_stride = qT.Stride();
|
||||||
|
size_t position = start_pos;
|
||||||
|
while (position + kHTileSize - 1 <= min_last_pos) {
|
||||||
|
size_t k_pos[kHTileSize];
|
||||||
|
for (size_t i = 0; i < kHTileSize; ++i) {
|
||||||
|
k_pos[i] = activations.div_seq_len.Remainder(position + i);
|
||||||
|
}
|
||||||
|
VF x0, x1, x2, x3, x4, x5, x6, x7;
|
||||||
|
QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, p, worker, x0, x1, x2, x3,
|
||||||
|
x4, x5, x6, x7);
|
||||||
|
if (activations.config.att_cap > 0.0f) {
|
||||||
|
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
|
||||||
|
VF cap = hn::Set(df, activations.config.att_cap);
|
||||||
|
VF one_over_cap = hn::Div(hn::Set(df, 1.0f), cap);
|
||||||
|
x0 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x0, one_over_cap)));
|
||||||
|
x1 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x1, one_over_cap)));
|
||||||
|
x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap)));
|
||||||
|
x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap)));
|
||||||
|
x4 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x4, one_over_cap)));
|
||||||
|
x5 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x5, one_over_cap)));
|
||||||
|
x6 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x6, one_over_cap)));
|
||||||
|
x7 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x7, one_over_cap)));
|
||||||
|
}
|
||||||
|
VF m = ElementwiseMaxOf8(df, x0, x1, x2, x3, x4, x5, x6, x7);
|
||||||
|
m = hn::Max(old_m, m);
|
||||||
|
x0 = hn::Exp(df, x0 - m);
|
||||||
|
x1 = hn::Exp(df, x1 - m);
|
||||||
|
x2 = hn::Exp(df, x2 - m);
|
||||||
|
x3 = hn::Exp(df, x3 - m);
|
||||||
|
x4 = hn::Exp(df, x4 - m);
|
||||||
|
x5 = hn::Exp(df, x5 - m);
|
||||||
|
x6 = hn::Exp(df, x6 - m);
|
||||||
|
x7 = hn::Exp(df, x7 - m);
|
||||||
|
VF scale = hn::Mul(old_d, hn::Exp(df, old_m - m));
|
||||||
|
old_d = ElementwiseSumOf8(df, x0, x1, x2, x3, x4, x5, x6, x7);
|
||||||
|
old_d = hn::Add(scale, old_d);
|
||||||
|
old_m = m;
|
||||||
|
VF one_over_d = hn::Div(hn::Set(df, 1.0f), old_d);
|
||||||
|
scale = hn::Mul(scale, one_over_d);
|
||||||
|
x0 = hn::Mul(x0, one_over_d);
|
||||||
|
x1 = hn::Mul(x1, one_over_d);
|
||||||
|
x2 = hn::Mul(x2, one_over_d);
|
||||||
|
x3 = hn::Mul(x3, one_over_d);
|
||||||
|
x4 = hn::Mul(x4, one_over_d);
|
||||||
|
x5 = hn::Mul(x5, one_over_d);
|
||||||
|
x6 = hn::Mul(x6, one_over_d);
|
||||||
|
x7 = hn::Mul(x7, one_over_d);
|
||||||
|
MulByConstAndAddTile(df, scale, x0, x1, x2, x3, x4, x5, x6, x7, v, k_pos,
|
||||||
|
att_out.Row(0), out_offsets, v.Cols(), p, worker);
|
||||||
|
position += kHTileSize;
|
||||||
|
}
|
||||||
|
while (position <= max_last_pos) {
|
||||||
|
size_t k_pos = activations.div_seq_len.Remainder(position);
|
||||||
|
VF x0 = QDotKVector(df, q_offsets, k_pos, q, k, p, worker);
|
||||||
|
if (activations.config.att_cap > 0.0f) {
|
||||||
|
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the vector.
|
||||||
|
VF cap = hn::Set(df, activations.config.att_cap);
|
||||||
|
VF one_over_cap = hn::Div(hn::Set(df, 1.0f), cap);
|
||||||
|
x0 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x0, one_over_cap)));
|
||||||
|
}
|
||||||
|
// Past the last position, x0 doesn't count.
|
||||||
|
auto mask = hn::Gt(hn::Set(di, position), lasts);
|
||||||
|
VF causal_offset = hn::MaskedSet(df, RebindMask(df, mask),
|
||||||
|
std::numeric_limits<float>::max() / 2.0f);
|
||||||
|
x0 = hn::Sub(x0, causal_offset);
|
||||||
|
VF m = hn::Max(old_m, x0);
|
||||||
|
x0 = hn::Exp(df, x0 - m);
|
||||||
|
VF scale = hn::Mul(old_d, hn::Exp(df, old_m - m));
|
||||||
|
old_m = m;
|
||||||
|
old_d = hn::Add(scale, x0);
|
||||||
|
VF one_over_d = hn::Div(hn::Set(df, 1.0f), old_d);
|
||||||
|
x0 = hn::Mul(x0, one_over_d);
|
||||||
|
scale = hn::Mul(scale, one_over_d);
|
||||||
|
MulByConstAndAddVector(df, scale, x0, v, k_pos, att_out.Row(0), out_offsets,
|
||||||
|
v.Cols(), p, worker);
|
||||||
|
++position;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The nominal aim of attention is to combine 3 inputs Q[L,D], K[L,D], V[L,D]
|
||||||
|
// into a single output O[L,D].
|
||||||
|
// Conventional attention first computes A[L,L] = Q . KT
|
||||||
|
// followed by A = softmax(A) (over invididual rows).
|
||||||
|
// Then A is multiplied by V to get O[L,D].
|
||||||
|
// For each row of O, this takes a read of one row of Q L times, all of K,
|
||||||
|
// 3 write/reads of one row of A, read all of V, an read.write the one row of O
|
||||||
|
// L times. Ignoring the computation for now, and focusing just on memory,
|
||||||
|
// the one row of O takes L(4D+3) reads and L(D+3) writes.
|
||||||
|
// For the whole of Q, this is L^2(4D+3) reads and L^2(D+3) writes.
|
||||||
|
//
|
||||||
|
// Flash attention fuses these operations together, and (where possible)
|
||||||
|
// computes NF rows of the result using 8 accumulator registers and two more to
|
||||||
|
// keep running results. NF is the number of float lanes in a register, being 16
|
||||||
|
// for AVX3. The softmax is converted to streaming form using the
|
||||||
|
// algortihm from:
|
||||||
|
// https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf.
|
||||||
|
// Q is transposed to Q_T[D,L] to make the dot product computation efficient.
|
||||||
|
// QDotKTileFloat computes 8xNF rows of Q.K dot products in one go, reducing
|
||||||
|
// reads of Q by 8 and reads of K by NF. The streaming softmax is computed
|
||||||
|
// entirely in registers, and a further NF registers to accumulate the results
|
||||||
|
// of the product of the softmax and V, reduce the number of reads of V by NF,
|
||||||
|
// and the reads/writes of O by 8.
|
||||||
|
// The reads are thus reduced to 2DL^2(1/8+1/NF) and writes reduced to DL^2/8,
|
||||||
|
// which on AVX3 is an overall reduction by about a factor of 10.
|
||||||
|
//
|
||||||
|
// A further complication is that real attention is not as simple as documented
|
||||||
|
// in the paper and above. There are multiple query heads, differing KV, and
|
||||||
|
// different sequence lengths, so a lot of the work in FlashAttention is making
|
||||||
|
// sure that a collection of q rows can use the TileFlashAttention path.
|
||||||
|
void FlashAttention(const size_t num_tokens, const size_t layer_idx,
|
||||||
|
const LayerWeightsPtrs& layer,
|
||||||
|
AttentionActivations& activations, QBatch& qbatch,
|
||||||
|
ThreadingContext& ctx) {
|
||||||
|
static const auto zone = ctx.profiler.AddZone("Gen.Attention.FlashAttention");
|
||||||
|
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx,
|
||||||
|
layer, activations, ctx);
|
||||||
|
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||||
|
const LayerConfig& layer_config = layer.layer_config;
|
||||||
|
const size_t qkv_dim = layer_config.qkv_dim;
|
||||||
|
|
||||||
|
// A "head group" in the context of GQA refers to a collection of query
|
||||||
|
// heads that share the same key and value heads.
|
||||||
|
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
|
||||||
|
|
||||||
|
using DF = hn::ScalableTag<float>;
|
||||||
|
const DF df;
|
||||||
|
constexpr size_t kVTileSize = hn::MaxLanes(df);
|
||||||
|
const size_t cache_layer_size = layer_config.CacheLayerSize();
|
||||||
|
const size_t seq_len =
|
||||||
|
static_cast<size_t>(activations.div_seq_len.GetDivisor());
|
||||||
|
const size_t token_batch = num_tokens * div_qbatch.GetDivisor();
|
||||||
|
const size_t total_tasks = token_batch * layer_config.heads;
|
||||||
|
// q has shape [batch, qbatch][head, qkv_dim].
|
||||||
|
// We transpose it to [qkv_dim][qbatch, head, batch] in order to make the
|
||||||
|
// maximum possible number of consecutive columns have the same KV matrices.
|
||||||
|
// Each thread will process a tile of NF columns of QT so the starting column
|
||||||
|
// index of QT is just the task index * kVTileSize.
|
||||||
|
TransposeQ(activations.q, activations.q_T, qbatch.Size(), ctx);
|
||||||
|
const size_t num_thread_tasks = hwy::DivCeil(total_tasks, kVTileSize);
|
||||||
|
const hwy::Divisor div_tokens(num_tokens);
|
||||||
|
// All layers should have the same number of heads.
|
||||||
|
HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads);
|
||||||
|
|
||||||
|
// For each head/token/query, compute fused flash Q.K, softmax and weighted V.
|
||||||
|
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
|
||||||
|
PROFILER_ZONE3(ctx.profiler, worker, zone);
|
||||||
|
// Offsets into original Q for each row in the tile.
|
||||||
|
uint32_t q_offsets[kVTileSize];
|
||||||
|
// Offsets into att_out for each row in the tile.
|
||||||
|
uint32_t out_offsets[kVTileSize];
|
||||||
|
// Start positions for each row in the tile.
|
||||||
|
size_t start_positions[kVTileSize];
|
||||||
|
// Last positions for each row in the tile. Inclusive.
|
||||||
|
uint32_t last_pos[kVTileSize];
|
||||||
|
// min and max last positions across all rows in the tile determines when
|
||||||
|
// TileFlashAttention switches to single vector mode to handle the
|
||||||
|
// ragged sequence lengths.
|
||||||
|
size_t min_last_pos = std::numeric_limits<size_t>::max();
|
||||||
|
size_t max_last_pos = 0;
|
||||||
|
// Indices into the qbatch.KV for each row in the tile.
|
||||||
|
size_t qi_indices[kVTileSize];
|
||||||
|
// Indices into the kv_cache for each row in the tile.
|
||||||
|
size_t kv_offsets[kVTileSize];
|
||||||
|
// first_task is [qbatch, head, token].
|
||||||
|
const size_t first_task = task * kVTileSize;
|
||||||
|
const size_t last_task = first_task + kVTileSize - 1;
|
||||||
|
bool use_tile_attention = last_task < total_tasks;
|
||||||
|
for (size_t offset = 0;
|
||||||
|
offset < kVTileSize && first_task + offset < total_tasks; ++offset) {
|
||||||
|
const size_t batch_idx = div_tokens.Remainder(first_task + offset);
|
||||||
|
const size_t qh = div_tokens.Divide(first_task + offset);
|
||||||
|
const size_t head = activations.div_heads.Remainder(qh);
|
||||||
|
const size_t qi = activations.div_heads.Divide(qh);
|
||||||
|
const size_t tq_idx = div_qbatch.GetDivisor() * batch_idx + qi;
|
||||||
|
qi_indices[offset] = qi;
|
||||||
|
|
||||||
|
// Find the token position in the query and calculate
|
||||||
|
// the range of cache positions to attend to.
|
||||||
|
const size_t pos = qbatch.Pos(qi) + batch_idx;
|
||||||
|
const size_t start_pos = StartPos(pos, activations.config, layer_idx);
|
||||||
|
start_positions[offset] = start_pos;
|
||||||
|
size_t last = pos;
|
||||||
|
const size_t prefix_end = qbatch.PrefixEnd(qi);
|
||||||
|
if (prefix_end > 0 && prefix_end - 1 > last) {
|
||||||
|
// last_pos in QDotK and WeightedSumV is inclusive.
|
||||||
|
last = prefix_end - 1;
|
||||||
|
}
|
||||||
|
last_pos[offset] = last;
|
||||||
|
min_last_pos = HWY_MIN(min_last_pos, last);
|
||||||
|
max_last_pos = HWY_MAX(max_last_pos, last);
|
||||||
|
q_offsets[offset] =
|
||||||
|
activations.q.Row(tq_idx) + head * qkv_dim - activations.q.Row(0);
|
||||||
|
out_offsets[offset] = activations.att_out.Row(tq_idx) + head * qkv_dim -
|
||||||
|
activations.att_out.Row(0);
|
||||||
|
const size_t kv_index = head / kHeadGroups;
|
||||||
|
const size_t head_offset = kv_index * qkv_dim * 2;
|
||||||
|
kv_offsets[offset] = layer_idx * cache_layer_size + head_offset;
|
||||||
|
// If any of the parameters in this if statement differ within this task,
|
||||||
|
// then we can't use TileFlashAttention. TileFlashAttention requires that
|
||||||
|
// all rows in the tile have the same K and V matrices, and Q starts at
|
||||||
|
// the same position. The end positions do not have to be the equal.
|
||||||
|
if (start_positions[offset] != start_positions[0] ||
|
||||||
|
qi_indices[offset] != qi_indices[0] ||
|
||||||
|
kv_offsets[offset] != kv_offsets[0]) {
|
||||||
|
use_tile_attention = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (size_t offset = 0;
|
||||||
|
offset < kVTileSize && first_task + offset < total_tasks; ++offset) {
|
||||||
|
auto& kv_cache = qbatch.KV(qi_indices[offset]).kv_cache;
|
||||||
|
MatPtrT<KV_t> k("k_view", Extents2D(seq_len, qkv_dim));
|
||||||
|
k.SetPtr(kv_cache.Row(0) + kv_offsets[offset], kv_cache.Stride());
|
||||||
|
MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
|
||||||
|
v.SetPtr(kv_cache.Row(0) + kv_offsets[offset] + qkv_dim,
|
||||||
|
kv_cache.Stride());
|
||||||
|
if (use_tile_attention) {
|
||||||
|
// To avoid duplicating the code to setup K and V, the call to
|
||||||
|
// TileFlashAttention is inside the loop over tasks, even thought it
|
||||||
|
// handles all rows in the task at once.
|
||||||
|
StridedView<float> qT =
|
||||||
|
StridedView<float>(activations.q_T.Row(0) + first_task, kVTileSize,
|
||||||
|
activations.q_T.Stride());
|
||||||
|
TileFlashAttention(
|
||||||
|
activations.q, q_offsets, qT, k, start_positions[offset], last_pos,
|
||||||
|
min_last_pos, max_last_pos, v, layer_idx, layer, activations,
|
||||||
|
activations.att_out, out_offsets, ctx.profiler, worker);
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
SingleFlashAttention(start_positions[offset], last_pos[offset],
|
||||||
|
activations.q.Row(0) + q_offsets[offset], k, v,
|
||||||
|
layer_idx, layer, activations,
|
||||||
|
activations.att_out.Row(0) + out_offsets[offset],
|
||||||
|
ctx.profiler, worker);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
{
|
||||||
|
PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
|
||||||
|
// Full parallelism is helpful, SmallParallelFor is insufficient.
|
||||||
|
HierarchicalParallelFor(num_thread_tasks, ctx.pools, func);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||||
|
} // namespace HWY_NAMESPACE
|
||||||
|
} // namespace gcpp
|
||||||
|
HWY_AFTER_NAMESPACE();
|
||||||
|
|
@ -0,0 +1,61 @@
|
||||||
|
// Copyright 2025 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_
|
||||||
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_
|
||||||
|
|
||||||
|
// Declares FlashAttention for all SIMD targets.
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
|
||||||
|
#include "gemma/gemma.h"
|
||||||
|
#include "hwy/highway.h"
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
// Passed to HWY_VISIT_TARGETS; declares for one target.
|
||||||
|
#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \
|
||||||
|
namespace NAMESPACE { \
|
||||||
|
void RMSNormAndPositionalEncoding(size_t num_tokens, const QBatch& qbatch, \
|
||||||
|
MatPtrT<KV_t>& q, size_t layer_idx, \
|
||||||
|
const LayerWeightsPtrs& layer, \
|
||||||
|
const AttentionActivations& activations, \
|
||||||
|
ThreadingContext& ctx); \
|
||||||
|
\
|
||||||
|
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
|
||||||
|
const float* HWY_RESTRICT q, \
|
||||||
|
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
|
||||||
|
size_t layer_idx, const LayerWeightsPtrs& layer, \
|
||||||
|
const AttentionActivations& activations, \
|
||||||
|
float* HWY_RESTRICT att_out, hwy::Profiler& p, \
|
||||||
|
size_t worker); \
|
||||||
|
\
|
||||||
|
void FlashAttention(size_t num_tokens, size_t layer_idx, \
|
||||||
|
const LayerWeightsPtrs& layer, \
|
||||||
|
AttentionActivations& activations, QBatch& qbatch, \
|
||||||
|
ThreadingContext& ctx); \
|
||||||
|
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
||||||
|
} // namespace NAMESPACE
|
||||||
|
|
||||||
|
// Function declarations for each SIMD target. Allows direct call from the
|
||||||
|
// per-target namespace. We may later replace this with dynamic dispatch if
|
||||||
|
// the overhead is acceptable.
|
||||||
|
HWY_VISIT_TARGETS(GEMMA_DECL_FLASH_ATTENTION)
|
||||||
|
|
||||||
|
#undef GEMMA_DECL_FLASH_ATTENTION
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_
|
||||||
|
|
@ -0,0 +1,171 @@
|
||||||
|
// Copyright 2025 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <cstring>
|
||||||
|
#include <numeric>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "compression/types.h"
|
||||||
|
#include "gemma/activations.h"
|
||||||
|
#include "gemma/gemma.h"
|
||||||
|
#include "gemma/gemma_args.h"
|
||||||
|
#include "gemma/kv_cache.h"
|
||||||
|
#include "gemma/weights.h"
|
||||||
|
#include "ops/matmul.h"
|
||||||
|
#ifndef HWY_DISABLED_TARGETS
|
||||||
|
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||||
|
#endif // HWY_DISABLED_TARGETS
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include <algorithm> // std::max
|
||||||
|
#include <cmath> // std::abs
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "util/mat.h"
|
||||||
|
#include "util/threading_context.h"
|
||||||
|
#include "hwy/aligned_allocator.h"
|
||||||
|
#include "hwy/base.h"
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
#undef HWY_TARGET_INCLUDE
|
||||||
|
#define HWY_TARGET_INCLUDE "gemma/flash_attention_test.cc" // NOLINT
|
||||||
|
// clang-format on
|
||||||
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
|
#include "hwy/highway.h"
|
||||||
|
// After highway.h
|
||||||
|
#include "compression/compress-inl.h"
|
||||||
|
#include "gemma/attention.h"
|
||||||
|
#include "gemma/configs.h"
|
||||||
|
#include "gemma/flash_attention.h"
|
||||||
|
#include "ops/matvec-inl.h"
|
||||||
|
#include "hwy/tests/test_util-inl.h"
|
||||||
|
|
||||||
|
HWY_BEFORE_NAMESPACE();
|
||||||
|
namespace gcpp {
|
||||||
|
namespace HWY_NAMESPACE {
|
||||||
|
|
||||||
|
using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
|
||||||
|
|
||||||
|
void SetMat(const size_t offset, MatPtrT<float>& mat) {
|
||||||
|
const size_t kOuter = mat.Extents().rows;
|
||||||
|
const size_t kInner = mat.Extents().cols;
|
||||||
|
const float i_scale = 1.0f / kInner;
|
||||||
|
const float j_scale = 1.0f / kOuter;
|
||||||
|
for (size_t i = 0; i < kOuter; ++i) {
|
||||||
|
float* row = mat.Row(i);
|
||||||
|
for (size_t j = 0; j < kInner; ++j) {
|
||||||
|
row[j] =
|
||||||
|
static_cast<float>((i * kInner * i_scale + (j + offset) * j_scale));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<MatStorageT<float>> MakeCopyOfMat(const MatPtrT<float>& mat,
|
||||||
|
const Allocator& allocator) {
|
||||||
|
auto copy = std::make_unique<MatStorageT<float>>("TestMat", mat.Extents(),
|
||||||
|
allocator, MatPadding::kOdd);
|
||||||
|
CopyMat(mat, *copy);
|
||||||
|
return copy;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AssertClose(const MatPtrT<float>& a, const MatPtrT<float>& b) {
|
||||||
|
// Avoid comparing the padding bytes, which are uninitialized.
|
||||||
|
for (size_t r = 0; r < a.Rows(); ++r) {
|
||||||
|
const float* HWY_RESTRICT a_row = a.Row(r);
|
||||||
|
const float* HWY_RESTRICT b_row = b.Row(r);
|
||||||
|
for (size_t c = 0; c < a.Cols(); ++c) {
|
||||||
|
float rel_abs_delta = std::abs(a_row[c] - b_row[c]);
|
||||||
|
if (rel_abs_delta > 0.0f) {
|
||||||
|
rel_abs_delta /= std::max(std::abs(a_row[c]), std::abs(b_row[c]));
|
||||||
|
}
|
||||||
|
EXPECT_LT(rel_abs_delta, 1e-5)
|
||||||
|
<< "a[" << r << "," << c << "]=" << a_row[c] << ", b[" << r << ","
|
||||||
|
<< c << "]=" << b_row[c];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestAttention() {
|
||||||
|
ThreadingArgs threading_args;
|
||||||
|
ThreadingContext ctx(threading_args);
|
||||||
|
// hwy::ThreadPool& pool = ctx.pools.Pool();
|
||||||
|
constexpr size_t kOuter = 1024;
|
||||||
|
constexpr size_t kInner = 256;
|
||||||
|
ModelConfig config(Model::GEMMA2_2B, Type::kF32, PromptWrapping::GEMMA_PT);
|
||||||
|
TensorInfoRegistry tensor_info_registry(config);
|
||||||
|
const LayerConfig& layer_config = config.layer_configs[0];
|
||||||
|
const LayerWeightsPtrs layers(0, layer_config, tensor_info_registry);
|
||||||
|
InferenceArgs inference_args;
|
||||||
|
RuntimeConfig runtime_config;
|
||||||
|
KVCache kv_cache(config, inference_args, ctx.allocator);
|
||||||
|
MatMulEnv env(ctx);
|
||||||
|
Activations activations(config, runtime_config.prefill_tbatch_size,
|
||||||
|
kv_cache.SeqLen(), env.ctx, env.row_ptrs);
|
||||||
|
std::vector<int> tokens(kOuter);
|
||||||
|
std::iota(tokens.begin(), tokens.end(), 1);
|
||||||
|
PromptTokens prompt(tokens);
|
||||||
|
AllQueries all_queries(hwy::Span<const PromptTokens>(&prompt, 1),
|
||||||
|
hwy::Span<KVCache>(&kv_cache, 1));
|
||||||
|
QBatch qbatch(/*start=*/0, /*max_size=*/kOuter, all_queries);
|
||||||
|
const size_t batch_size = kOuter;
|
||||||
|
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
|
||||||
|
AttentionActivations attention(config, layer_config, batch_size, kOuter,
|
||||||
|
ctx.allocator, row_ptrs);
|
||||||
|
const size_t qkv_dim = layer_config.qkv_dim;
|
||||||
|
ASSERT_EQ(qkv_dim, kInner);
|
||||||
|
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||||
|
// A "head group" in the context of GQA refers to a collection of query
|
||||||
|
// heads that share the same key and value heads.
|
||||||
|
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
|
||||||
|
const size_t seq_len =
|
||||||
|
static_cast<size_t>(attention.div_seq_len.GetDivisor());
|
||||||
|
auto& kvc = qbatch.KV(0).kv_cache;
|
||||||
|
for (size_t h = 0; h < layer_config.heads; ++h) {
|
||||||
|
// Make strided views into the kv cache for
|
||||||
|
// this query and head.
|
||||||
|
const size_t head_offset = (h / kHeadGroups) * qkv_dim * 2;
|
||||||
|
MatPtrT<KV_t> k("k_view", Extents2D(seq_len, qkv_dim));
|
||||||
|
k.SetPtr(kvc.Row(0) + head_offset, kvc.Stride());
|
||||||
|
MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
|
||||||
|
v.SetPtr(kvc.Row(0) + head_offset + qkv_dim, kvc.Stride());
|
||||||
|
SetMat(h + layer_config.heads, k);
|
||||||
|
SetMat(h + layer_config.heads * 2, v);
|
||||||
|
}
|
||||||
|
SetMat(1, attention.q);
|
||||||
|
DotSoftmaxWeightedSum(tokens.size(), 0, layers, attention, qbatch, ctx);
|
||||||
|
// Copy the output to saved_att to allow for comparison.
|
||||||
|
auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator);
|
||||||
|
SetMat(1, attention.q);
|
||||||
|
FlashAttention(tokens.size(), 0, layers, attention, qbatch, ctx);
|
||||||
|
AssertClose(attention.att_out, *saved_att);
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||||
|
} // namespace HWY_NAMESPACE
|
||||||
|
} // namespace gcpp
|
||||||
|
HWY_AFTER_NAMESPACE();
|
||||||
|
|
||||||
|
#if HWY_ONCE
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
HWY_BEFORE_TEST(FlashAttentionTest);
|
||||||
|
HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestAttention);
|
||||||
|
HWY_AFTER_TEST();
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
#endif
|
||||||
345
ops/ops-inl.h
345
ops/ops-inl.h
|
|
@ -613,6 +613,351 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <class DF, class VF = hn::Vec<DF>>
|
||||||
|
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16(
|
||||||
|
DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2,
|
||||||
|
VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9,
|
||||||
|
VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) {
|
||||||
|
sum0 = hn::MulAdd(common, hn::Set(df, split.raw[0]), sum0);
|
||||||
|
sum1 = hn::MulAdd(common, hn::Set(df, split.raw[1]), sum1);
|
||||||
|
sum2 = hn::MulAdd(common, hn::Set(df, split.raw[2]), sum2);
|
||||||
|
sum3 = hn::MulAdd(common, hn::Set(df, split.raw[3]), sum3);
|
||||||
|
sum4 = hn::MulAdd(common, hn::Set(df, split.raw[4]), sum4);
|
||||||
|
sum5 = hn::MulAdd(common, hn::Set(df, split.raw[5]), sum5);
|
||||||
|
sum6 = hn::MulAdd(common, hn::Set(df, split.raw[6]), sum6);
|
||||||
|
sum7 = hn::MulAdd(common, hn::Set(df, split.raw[7]), sum7);
|
||||||
|
sum8 = hn::MulAdd(common, hn::Set(df, split.raw[8]), sum8);
|
||||||
|
sum9 = hn::MulAdd(common, hn::Set(df, split.raw[9]), sum9);
|
||||||
|
sum10 = hn::MulAdd(common, hn::Set(df, split.raw[10]), sum10);
|
||||||
|
sum11 = hn::MulAdd(common, hn::Set(df, split.raw[11]), sum11);
|
||||||
|
sum12 = hn::MulAdd(common, hn::Set(df, split.raw[12]), sum12);
|
||||||
|
sum13 = hn::MulAdd(common, hn::Set(df, split.raw[13]), sum13);
|
||||||
|
sum14 = hn::MulAdd(common, hn::Set(df, split.raw[14]), sum14);
|
||||||
|
sum15 = hn::MulAdd(common, hn::Set(df, split.raw[15]), sum15);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class DF, class VF = hn::Vec<DF>>
|
||||||
|
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const VF common, const VF split,
|
||||||
|
VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
||||||
|
VF& sum4, VF& sum5, VF& sum6,
|
||||||
|
VF& sum7) {
|
||||||
|
sum0 = hn::MulAdd(common, hn::Set(df, split.raw[0]), sum0);
|
||||||
|
sum1 = hn::MulAdd(common, hn::Set(df, split.raw[1]), sum1);
|
||||||
|
sum2 = hn::MulAdd(common, hn::Set(df, split.raw[2]), sum2);
|
||||||
|
sum3 = hn::MulAdd(common, hn::Set(df, split.raw[3]), sum3);
|
||||||
|
sum4 = hn::MulAdd(common, hn::Set(df, split.raw[4]), sum4);
|
||||||
|
sum5 = hn::MulAdd(common, hn::Set(df, split.raw[5]), sum5);
|
||||||
|
sum6 = hn::MulAdd(common, hn::Set(df, split.raw[6]), sum6);
|
||||||
|
sum7 = hn::MulAdd(common, hn::Set(df, split.raw[7]), sum7);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class DF, class VF = hn::Vec<DF>>
|
||||||
|
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF split,
|
||||||
|
VF& sum0, VF& sum1, VF& sum2,
|
||||||
|
VF& sum3) {
|
||||||
|
sum0 = hn::MulAdd(common, hn::Set(df, split.raw[0]), sum0);
|
||||||
|
sum1 = hn::MulAdd(common, hn::Set(df, split.raw[1]), sum1);
|
||||||
|
sum2 = hn::MulAdd(common, hn::Set(df, split.raw[2]), sum2);
|
||||||
|
sum3 = hn::MulAdd(common, hn::Set(df, split.raw[3]), sum3);
|
||||||
|
}
|
||||||
|
|
||||||
|
// For an 8xNF tile of float values in 8xNF-lane registers, multiplies 8 rows
|
||||||
|
// of V by the corresponding values in c0-c7 and adds them to NF rows of out,
|
||||||
|
// after first prescaling out by scale.
|
||||||
|
// The depth (size) must be a multiple of NF.
|
||||||
|
template <class DF, class VF = hn::Vec<DF>>
|
||||||
|
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile(
|
||||||
|
DF df, const VF scale, const VF c0, const VF c1, const VF c2, const VF c3,
|
||||||
|
const VF c4, const VF c5, const VF c6, const VF c7, const MatPtrT<float>& v,
|
||||||
|
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
|
||||||
|
const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
|
||||||
|
hwy::Profiler& p, const size_t worker) {
|
||||||
|
static const auto zone = p.AddZone("Ops.MulByConstAndAdd");
|
||||||
|
PROFILER_ZONE3(p, worker, zone);
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
HWY_LANES_CONSTEXPR size_t NF = hn::MaxLanes(df);
|
||||||
|
|
||||||
|
size_t i = 0;
|
||||||
|
while (i + NF <= size) {
|
||||||
|
if HWY_LANES_CONSTEXPR (NF == 16) {
|
||||||
|
VF out0, out1, out2, out3, out4, out5, out6, out7;
|
||||||
|
VF out8, out9, out10, out11, out12, out13, out14, out15;
|
||||||
|
out0 = hn::Load(df, out + i + out_offsets[0]);
|
||||||
|
out1 = hn::Load(df, out + i + out_offsets[1]);
|
||||||
|
out2 = hn::Load(df, out + i + out_offsets[2]);
|
||||||
|
out3 = hn::Load(df, out + i + out_offsets[3]);
|
||||||
|
out4 = hn::Load(df, out + i + out_offsets[4]);
|
||||||
|
out5 = hn::Load(df, out + i + out_offsets[5]);
|
||||||
|
out6 = hn::Load(df, out + i + out_offsets[6]);
|
||||||
|
out7 = hn::Load(df, out + i + out_offsets[7]);
|
||||||
|
out8 = hn::Load(df, out + i + out_offsets[8]);
|
||||||
|
out9 = hn::Load(df, out + i + out_offsets[9]);
|
||||||
|
out10 = hn::Load(df, out + i + out_offsets[10]);
|
||||||
|
out11 = hn::Load(df, out + i + out_offsets[11]);
|
||||||
|
out12 = hn::Load(df, out + i + out_offsets[12]);
|
||||||
|
out13 = hn::Load(df, out + i + out_offsets[13]);
|
||||||
|
out14 = hn::Load(df, out + i + out_offsets[14]);
|
||||||
|
out15 = hn::Load(df, out + i + out_offsets[15]);
|
||||||
|
out0 = hn::Mul(out0, hn::Set(df, scale.raw[0]));
|
||||||
|
out1 = hn::Mul(out1, hn::Set(df, scale.raw[1]));
|
||||||
|
out2 = hn::Mul(out2, hn::Set(df, scale.raw[2]));
|
||||||
|
out3 = hn::Mul(out3, hn::Set(df, scale.raw[3]));
|
||||||
|
out4 = hn::Mul(out4, hn::Set(df, scale.raw[4]));
|
||||||
|
out5 = hn::Mul(out5, hn::Set(df, scale.raw[5]));
|
||||||
|
out6 = hn::Mul(out6, hn::Set(df, scale.raw[6]));
|
||||||
|
out7 = hn::Mul(out7, hn::Set(df, scale.raw[7]));
|
||||||
|
out8 = hn::Mul(out8, hn::Set(df, scale.raw[8]));
|
||||||
|
out9 = hn::Mul(out9, hn::Set(df, scale.raw[9]));
|
||||||
|
out10 = hn::Mul(out10, hn::Set(df, scale.raw[10]));
|
||||||
|
out11 = hn::Mul(out11, hn::Set(df, scale.raw[11]));
|
||||||
|
out12 = hn::Mul(out12, hn::Set(df, scale.raw[12]));
|
||||||
|
out13 = hn::Mul(out13, hn::Set(df, scale.raw[13]));
|
||||||
|
out14 = hn::Mul(out14, hn::Set(df, scale.raw[14]));
|
||||||
|
out15 = hn::Mul(out15, hn::Set(df, scale.raw[15]));
|
||||||
|
VF x0 = hn::Load(df, v.Row(pos[0]) + i);
|
||||||
|
MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||||
|
out9, out10, out11, out12, out13, out14, out15);
|
||||||
|
VF x1 = hn::Load(df, v.Row(pos[1]) + i);
|
||||||
|
MulAdd16(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||||
|
out9, out10, out11, out12, out13, out14, out15);
|
||||||
|
VF x2 = hn::Load(df, v.Row(pos[2]) + i);
|
||||||
|
MulAdd16(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||||
|
out9, out10, out11, out12, out13, out14, out15);
|
||||||
|
VF x3 = hn::Load(df, v.Row(pos[3]) + i);
|
||||||
|
MulAdd16(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||||
|
out9, out10, out11, out12, out13, out14, out15);
|
||||||
|
VF x4 = hn::Load(df, v.Row(pos[4]) + i);
|
||||||
|
MulAdd16(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||||
|
out9, out10, out11, out12, out13, out14, out15);
|
||||||
|
VF x5 = hn::Load(df, v.Row(pos[5]) + i);
|
||||||
|
MulAdd16(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||||
|
out9, out10, out11, out12, out13, out14, out15);
|
||||||
|
VF x6 = hn::Load(df, v.Row(pos[6]) + i);
|
||||||
|
MulAdd16(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||||
|
out9, out10, out11, out12, out13, out14, out15);
|
||||||
|
VF x7 = hn::Load(df, v.Row(pos[7]) + i);
|
||||||
|
MulAdd16(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||||
|
out9, out10, out11, out12, out13, out14, out15);
|
||||||
|
hn::Store(out0, df, out + i + out_offsets[0]);
|
||||||
|
hn::Store(out1, df, out + i + out_offsets[1]);
|
||||||
|
hn::Store(out2, df, out + i + out_offsets[2]);
|
||||||
|
hn::Store(out3, df, out + i + out_offsets[3]);
|
||||||
|
hn::Store(out4, df, out + i + out_offsets[4]);
|
||||||
|
hn::Store(out5, df, out + i + out_offsets[5]);
|
||||||
|
hn::Store(out6, df, out + i + out_offsets[6]);
|
||||||
|
hn::Store(out7, df, out + i + out_offsets[7]);
|
||||||
|
hn::Store(out8, df, out + i + out_offsets[8]);
|
||||||
|
hn::Store(out9, df, out + i + out_offsets[9]);
|
||||||
|
hn::Store(out10, df, out + i + out_offsets[10]);
|
||||||
|
hn::Store(out11, df, out + i + out_offsets[11]);
|
||||||
|
hn::Store(out12, df, out + i + out_offsets[12]);
|
||||||
|
hn::Store(out13, df, out + i + out_offsets[13]);
|
||||||
|
hn::Store(out14, df, out + i + out_offsets[14]);
|
||||||
|
hn::Store(out15, df, out + i + out_offsets[15]);
|
||||||
|
}
|
||||||
|
if HWY_LANES_CONSTEXPR (NF == 8) {
|
||||||
|
VF out0, out1, out2, out3, out4, out5, out6, out7;
|
||||||
|
out0 = hn::Load(df, out + i + out_offsets[0]);
|
||||||
|
out1 = hn::Load(df, out + i + out_offsets[1]);
|
||||||
|
out2 = hn::Load(df, out + i + out_offsets[2]);
|
||||||
|
out3 = hn::Load(df, out + i + out_offsets[3]);
|
||||||
|
out4 = hn::Load(df, out + i + out_offsets[4]);
|
||||||
|
out5 = hn::Load(df, out + i + out_offsets[5]);
|
||||||
|
out6 = hn::Load(df, out + i + out_offsets[6]);
|
||||||
|
out7 = hn::Load(df, out + i + out_offsets[7]);
|
||||||
|
out0 = hn::Mul(out0, hn::Set(df, scale.raw[0]));
|
||||||
|
out1 = hn::Mul(out1, hn::Set(df, scale.raw[1]));
|
||||||
|
out2 = hn::Mul(out2, hn::Set(df, scale.raw[2]));
|
||||||
|
out3 = hn::Mul(out3, hn::Set(df, scale.raw[3]));
|
||||||
|
out4 = hn::Mul(out4, hn::Set(df, scale.raw[4]));
|
||||||
|
out5 = hn::Mul(out5, hn::Set(df, scale.raw[5]));
|
||||||
|
out6 = hn::Mul(out6, hn::Set(df, scale.raw[6]));
|
||||||
|
out7 = hn::Mul(out7, hn::Set(df, scale.raw[7]));
|
||||||
|
VF x0 = hn::Load(df, v.Row(pos[0]) + i);
|
||||||
|
MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||||
|
VF x1 = hn::Load(df, v.Row(pos[1]) + i);
|
||||||
|
MulAdd8(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||||
|
VF x2 = hn::Load(df, v.Row(pos[2]) + i);
|
||||||
|
MulAdd8(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||||
|
VF x3 = hn::Load(df, v.Row(pos[3]) + i);
|
||||||
|
MulAdd8(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||||
|
VF x4 = hn::Load(df, v.Row(pos[4]) + i);
|
||||||
|
MulAdd8(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||||
|
VF x5 = hn::Load(df, v.Row(pos[5]) + i);
|
||||||
|
MulAdd8(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||||
|
VF x6 = hn::Load(df, v.Row(pos[6]) + i);
|
||||||
|
MulAdd8(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||||
|
VF x7 = hn::Load(df, v.Row(pos[7]) + i);
|
||||||
|
MulAdd8(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||||
|
hn::Store(out0, df, out + i + out_offsets[0]);
|
||||||
|
hn::Store(out1, df, out + i + out_offsets[1]);
|
||||||
|
hn::Store(out2, df, out + i + out_offsets[2]);
|
||||||
|
hn::Store(out3, df, out + i + out_offsets[3]);
|
||||||
|
hn::Store(out4, df, out + i + out_offsets[4]);
|
||||||
|
hn::Store(out5, df, out + i + out_offsets[5]);
|
||||||
|
hn::Store(out6, df, out + i + out_offsets[6]);
|
||||||
|
hn::Store(out7, df, out + i + out_offsets[7]);
|
||||||
|
}
|
||||||
|
if HWY_LANES_CONSTEXPR (NF == 4) {
|
||||||
|
VF out0, out1, out2, out3;
|
||||||
|
out0 = hn::Load(df, out + i + out_offsets[0]);
|
||||||
|
out1 = hn::Load(df, out + i + out_offsets[1]);
|
||||||
|
out2 = hn::Load(df, out + i + out_offsets[2]);
|
||||||
|
out3 = hn::Load(df, out + i + out_offsets[3]);
|
||||||
|
out0 = hn::Mul(out0, hn::Set(df, scale.raw[0]));
|
||||||
|
out1 = hn::Mul(out1, hn::Set(df, scale.raw[1]));
|
||||||
|
out2 = hn::Mul(out2, hn::Set(df, scale.raw[2]));
|
||||||
|
out3 = hn::Mul(out3, hn::Set(df, scale.raw[3]));
|
||||||
|
VF x0 = hn::Load(df, v.Row(pos[0]) + i);
|
||||||
|
MulAdd4(df, x0, c0, out0, out1, out2, out3);
|
||||||
|
VF x1 = hn::Load(df, v.Row(pos[1]) + i);
|
||||||
|
MulAdd4(df, x1, c1, out0, out1, out2, out3);
|
||||||
|
VF x2 = hn::Load(df, v.Row(pos[2]) + i);
|
||||||
|
MulAdd4(df, x2, c2, out0, out1, out2, out3);
|
||||||
|
VF x3 = hn::Load(df, v.Row(pos[3]) + i);
|
||||||
|
MulAdd4(df, x3, c3, out0, out1, out2, out3);
|
||||||
|
VF x4 = hn::Load(df, v.Row(pos[4]) + i);
|
||||||
|
MulAdd4(df, x4, c4, out0, out1, out2, out3);
|
||||||
|
VF x5 = hn::Load(df, v.Row(pos[5]) + i);
|
||||||
|
MulAdd4(df, x5, c5, out0, out1, out2, out3);
|
||||||
|
VF x6 = hn::Load(df, v.Row(pos[6]) + i);
|
||||||
|
MulAdd4(df, x6, c6, out0, out1, out2, out3);
|
||||||
|
VF x7 = hn::Load(df, v.Row(pos[7]) + i);
|
||||||
|
MulAdd4(df, x7, c7, out0, out1, out2, out3);
|
||||||
|
hn::Store(out0, df, out + i + out_offsets[0]);
|
||||||
|
hn::Store(out1, df, out + i + out_offsets[1]);
|
||||||
|
hn::Store(out2, df, out + i + out_offsets[2]);
|
||||||
|
hn::Store(out3, df, out + i + out_offsets[3]);
|
||||||
|
}
|
||||||
|
i += NF;
|
||||||
|
}
|
||||||
|
const size_t remaining = size - i;
|
||||||
|
HWY_DASSERT(remaining == 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prescales NF rows of out by scale, then multiplies 1 row of V by the
|
||||||
|
// corresponding values in c0 and adds them to the NF rows of out.
|
||||||
|
// The depth (size) must be a multiple of NF.
|
||||||
|
template <class DF, class VF = hn::Vec<DF>>
|
||||||
|
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
|
||||||
|
DF df, const VF scale, const VF c0, const MatPtrT<float>& v,
|
||||||
|
const size_t pos, float* HWY_RESTRICT out,
|
||||||
|
const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
|
||||||
|
hwy::Profiler& p, const size_t worker) {
|
||||||
|
static const auto zone = p.AddZone("Ops.MulByConstAndAdd");
|
||||||
|
PROFILER_ZONE3(p, worker, zone);
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
const size_t NF = hn::MaxLanes(df);
|
||||||
|
|
||||||
|
size_t i = 0;
|
||||||
|
while (i + NF <= size) {
|
||||||
|
if constexpr (NF == 16) {
|
||||||
|
VF out0, out1, out2, out3, out4, out5, out6, out7;
|
||||||
|
VF out8, out9, out10, out11, out12, out13, out14, out15;
|
||||||
|
out0 = hn::Load(df, out + i + out_offsets[0]);
|
||||||
|
out1 = hn::Load(df, out + i + out_offsets[1]);
|
||||||
|
out2 = hn::Load(df, out + i + out_offsets[2]);
|
||||||
|
out3 = hn::Load(df, out + i + out_offsets[3]);
|
||||||
|
out4 = hn::Load(df, out + i + out_offsets[4]);
|
||||||
|
out5 = hn::Load(df, out + i + out_offsets[5]);
|
||||||
|
out6 = hn::Load(df, out + i + out_offsets[6]);
|
||||||
|
out7 = hn::Load(df, out + i + out_offsets[7]);
|
||||||
|
out8 = hn::Load(df, out + i + out_offsets[8]);
|
||||||
|
out9 = hn::Load(df, out + i + out_offsets[9]);
|
||||||
|
out10 = hn::Load(df, out + i + out_offsets[10]);
|
||||||
|
out11 = hn::Load(df, out + i + out_offsets[11]);
|
||||||
|
out12 = hn::Load(df, out + i + out_offsets[12]);
|
||||||
|
out13 = hn::Load(df, out + i + out_offsets[13]);
|
||||||
|
out14 = hn::Load(df, out + i + out_offsets[14]);
|
||||||
|
out15 = hn::Load(df, out + i + out_offsets[15]);
|
||||||
|
out0 = hn::Mul(out0, hn::Set(df, scale.raw[0]));
|
||||||
|
out1 = hn::Mul(out1, hn::Set(df, scale.raw[1]));
|
||||||
|
out2 = hn::Mul(out2, hn::Set(df, scale.raw[2]));
|
||||||
|
out3 = hn::Mul(out3, hn::Set(df, scale.raw[3]));
|
||||||
|
out4 = hn::Mul(out4, hn::Set(df, scale.raw[4]));
|
||||||
|
out5 = hn::Mul(out5, hn::Set(df, scale.raw[5]));
|
||||||
|
out6 = hn::Mul(out6, hn::Set(df, scale.raw[6]));
|
||||||
|
out7 = hn::Mul(out7, hn::Set(df, scale.raw[7]));
|
||||||
|
out8 = hn::Mul(out8, hn::Set(df, scale.raw[8]));
|
||||||
|
out9 = hn::Mul(out9, hn::Set(df, scale.raw[9]));
|
||||||
|
out10 = hn::Mul(out10, hn::Set(df, scale.raw[10]));
|
||||||
|
out11 = hn::Mul(out11, hn::Set(df, scale.raw[11]));
|
||||||
|
out12 = hn::Mul(out12, hn::Set(df, scale.raw[12]));
|
||||||
|
out13 = hn::Mul(out13, hn::Set(df, scale.raw[13]));
|
||||||
|
out14 = hn::Mul(out14, hn::Set(df, scale.raw[14]));
|
||||||
|
out15 = hn::Mul(out15, hn::Set(df, scale.raw[15]));
|
||||||
|
VF x0 = hn::Load(df, v.Row(pos) + i);
|
||||||
|
MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||||
|
out9, out10, out11, out12, out13, out14, out15);
|
||||||
|
hn::Store(out0, df, out + i + out_offsets[0]);
|
||||||
|
hn::Store(out1, df, out + i + out_offsets[1]);
|
||||||
|
hn::Store(out2, df, out + i + out_offsets[2]);
|
||||||
|
hn::Store(out3, df, out + i + out_offsets[3]);
|
||||||
|
hn::Store(out4, df, out + i + out_offsets[4]);
|
||||||
|
hn::Store(out5, df, out + i + out_offsets[5]);
|
||||||
|
hn::Store(out6, df, out + i + out_offsets[6]);
|
||||||
|
hn::Store(out7, df, out + i + out_offsets[7]);
|
||||||
|
hn::Store(out8, df, out + i + out_offsets[8]);
|
||||||
|
hn::Store(out9, df, out + i + out_offsets[9]);
|
||||||
|
hn::Store(out10, df, out + i + out_offsets[10]);
|
||||||
|
hn::Store(out11, df, out + i + out_offsets[11]);
|
||||||
|
hn::Store(out12, df, out + i + out_offsets[12]);
|
||||||
|
hn::Store(out13, df, out + i + out_offsets[13]);
|
||||||
|
hn::Store(out14, df, out + i + out_offsets[14]);
|
||||||
|
hn::Store(out15, df, out + i + out_offsets[15]);
|
||||||
|
} else if constexpr (NF == 8) {
|
||||||
|
VF out0, out1, out2, out3, out4, out5, out6, out7;
|
||||||
|
out0 = hn::Load(df, out + i + out_offsets[0]);
|
||||||
|
out1 = hn::Load(df, out + i + out_offsets[1]);
|
||||||
|
out2 = hn::Load(df, out + i + out_offsets[2]);
|
||||||
|
out3 = hn::Load(df, out + i + out_offsets[3]);
|
||||||
|
out4 = hn::Load(df, out + i + out_offsets[4]);
|
||||||
|
out5 = hn::Load(df, out + i + out_offsets[5]);
|
||||||
|
out6 = hn::Load(df, out + i + out_offsets[6]);
|
||||||
|
out7 = hn::Load(df, out + i + out_offsets[7]);
|
||||||
|
out0 = hn::Mul(out0, hn::Set(df, scale.raw[0]));
|
||||||
|
out1 = hn::Mul(out1, hn::Set(df, scale.raw[1]));
|
||||||
|
out2 = hn::Mul(out2, hn::Set(df, scale.raw[2]));
|
||||||
|
out3 = hn::Mul(out3, hn::Set(df, scale.raw[3]));
|
||||||
|
out4 = hn::Mul(out4, hn::Set(df, scale.raw[4]));
|
||||||
|
out5 = hn::Mul(out5, hn::Set(df, scale.raw[5]));
|
||||||
|
out6 = hn::Mul(out6, hn::Set(df, scale.raw[6]));
|
||||||
|
out7 = hn::Mul(out7, hn::Set(df, scale.raw[7]));
|
||||||
|
VF x0 = hn::Load(df, v.Row(pos) + i);
|
||||||
|
MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||||
|
hn::Store(out0, df, out + i + out_offsets[0]);
|
||||||
|
hn::Store(out1, df, out + i + out_offsets[1]);
|
||||||
|
hn::Store(out2, df, out + i + out_offsets[2]);
|
||||||
|
hn::Store(out3, df, out + i + out_offsets[3]);
|
||||||
|
hn::Store(out4, df, out + i + out_offsets[4]);
|
||||||
|
hn::Store(out5, df, out + i + out_offsets[5]);
|
||||||
|
hn::Store(out6, df, out + i + out_offsets[6]);
|
||||||
|
hn::Store(out7, df, out + i + out_offsets[7]);
|
||||||
|
} else if constexpr (NF == 4) {
|
||||||
|
VF out0, out1, out2, out3;
|
||||||
|
out0 = hn::Load(df, out + i + out_offsets[0]);
|
||||||
|
out1 = hn::Load(df, out + i + out_offsets[1]);
|
||||||
|
out2 = hn::Load(df, out + i + out_offsets[2]);
|
||||||
|
out3 = hn::Load(df, out + i + out_offsets[3]);
|
||||||
|
out0 = hn::Mul(out0, hn::Set(df, scale.raw[0]));
|
||||||
|
out1 = hn::Mul(out1, hn::Set(df, scale.raw[1]));
|
||||||
|
out2 = hn::Mul(out2, hn::Set(df, scale.raw[2]));
|
||||||
|
out3 = hn::Mul(out3, hn::Set(df, scale.raw[3]));
|
||||||
|
VF x0 = hn::Load(df, v.Row(pos) + i);
|
||||||
|
MulAdd4(df, x0, c0, out0, out1, out2, out3);
|
||||||
|
hn::Store(out0, df, out + i + out_offsets[0]);
|
||||||
|
hn::Store(out1, df, out + i + out_offsets[1]);
|
||||||
|
hn::Store(out2, df, out + i + out_offsets[2]);
|
||||||
|
hn::Store(out3, df, out + i + out_offsets[3]);
|
||||||
|
} else {
|
||||||
|
HWY_DASSERT(false);
|
||||||
|
}
|
||||||
|
i += NF;
|
||||||
|
}
|
||||||
|
const size_t remaining = size - i;
|
||||||
|
HWY_DASSERT(remaining == 0);
|
||||||
|
}
|
||||||
|
|
||||||
// See below for a specialized version for top-1 sampling.
|
// See below for a specialized version for top-1 sampling.
|
||||||
// TODO: support bf16 logits using Decompress2.
|
// TODO: support bf16 logits using Decompress2.
|
||||||
static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,
|
static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue