gemma.cpp/gemma/tiled_attention.cc


#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <iostream>
#include <limits>
#include <utility>
#include <vector>
#include "compression/compress.h"
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "gemma/kv_cache.h"
#include "ops/matmul.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
// Note: HWY_DISABLED_TARGETS needs to be defined the same everywhere.
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
#include "util/basics.h"
#include "util/mat.h"
#include "util/threading_context.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/tiled_attention.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "gemma/attention.h"
#include "gemma/flash_attention.h" // includes highway.h
#include "gemma/gemma-inl.h"
#include "ops/ops-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
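// Merges one partial flash-attention result (output `other_att_out`, running
// max `other_softmax_max`, softmax denominator `other_softmax_d`) into the
// accumulator using the standard online-softmax update:
//   m_new = max(m_acc, m_other)
//   d_new = d_acc * exp(m_acc - m_new) + d_other * exp(m_other - m_new)
// The two outputs are then blended with their renormalized weights.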
static HWY_INLINE void MergeOnlineSoftmax(
const float* HWY_RESTRICT other_att_out, const float other_softmax_max,
const float other_softmax_d, int qkv_dim,
float* HWY_RESTRICT accumulator_att_out, float& accumulator_softmax_max,
float& accumulator_softmax_d) {
if (other_softmax_d == 0.0f) {
return;
}
if (accumulator_softmax_d == 0.0f) {
memcpy(accumulator_att_out, other_att_out,
qkv_dim * sizeof(*accumulator_att_out));
accumulator_softmax_max = other_softmax_max;
accumulator_softmax_d = other_softmax_d;
return;
}
const float m_new = std::max(accumulator_softmax_max, other_softmax_max);
const float exp_l = std::exp(accumulator_softmax_max - m_new);
const float exp_r = std::exp(other_softmax_max - m_new);
const float d_new = accumulator_softmax_d * exp_l + other_softmax_d * exp_r;
const float d_new_inv = 1.0f / d_new;
const float c1 = accumulator_softmax_d * exp_l * d_new_inv;
const float c2 = other_softmax_d * exp_r * d_new_inv;
MulByConst(c1, accumulator_att_out, qkv_dim);
MulByConstAndAdd(c2, other_att_out, accumulator_att_out, qkv_dim);
accumulator_softmax_max = m_new;
accumulator_softmax_d = d_new;
}
// Forked from ComputeQKV, but stores the K/V in the tiled format.
// KV_T is the type stored in the KV cache (typically float or BF16).
template <typename KV_T>
static HWY_INLINE void ComputeQKVTransposedTile(
size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer,
AttentionImpl attention_impl, AttentionActivationsPtrs& activations,
const QBatch& qbatch, const int flags, MatMulEnv& env) {
PROFILER_ZONE("Gen.Attention.QKVTiled");
const hwy::Divisor div_qbatch(qbatch.Size());
const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor();
const LayerConfig& layer_config = layer.layer_config;
const size_t qkv_dim = layer_config.qkv_dim;
const size_t kv_heads = layer_config.kv_heads;
// The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim,
// model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows.
// This computes Q and stores it in activations.q.
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w1,
/*add=*/nullptr, env, activations.q);
// Compute the combined KV output from pre_att_rms_out.
// The output shape is [num_interleaved, kv_heads * 2 * qkv_dim].
const size_t kv_out_cols = kv_heads * 2 * qkv_dim;
hwy::AlignedFreeUniquePtr<float[]> kv_out_mem =
hwy::AllocateAligned<float>(num_interleaved * kv_out_cols);
float* kv_out_data = kv_out_mem.get();
MatPtrT<float> kv_out_mat("kv_out", Extents2D(num_interleaved, kv_out_cols));
kv_out_mat.SetPtr(kv_out_data, kv_out_cols);
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
/*add=*/nullptr, env, kv_out_mat);
// Apply positional encodings and store K/V in tiled format.
hwy::Divisor div_kv_heads(kv_heads);
hn::ScalableTag<float> df;
static hwy::Divisor tile_size_divisor(KVCache::kTileSize);
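// Each task handles one (kv_head, query) pair and walks the new tokens tile
// by tile: the affected KV tile is decompressed into k_tile_vec / v_tile_vec,
// the new K rows are written transposed at their in-tile positions (for the
// BF16 kernel, K and V are additionally packed in pairs), then the updated
// block(s) are compressed back into the cache.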
ParallelFor(
Parallelism::kFlat, kv_heads * qbatch.Size(), env.ctx,
/*cluster_idx=*/0, Callers::kAttComputeQKV,
[&](size_t task, size_t worker) HWY_ATTR {
const size_t kv_head = div_kv_heads.Remainder(task);
const size_t query_idx = div_kv_heads.Divide(task);
CompressPerThread tls;
size_t current_token_idx = 0;
float* k_tile_vec = activations.k_tile_vec.Row(task);
float* v_tile_vec = activations.v_tile_vec.Row(task);
HWY_ALIGN float k_f32[kMaxQKVDim];
const size_t start_pos = qbatch.Pos(query_idx);
const bool is_global_layer =
activations.config.IsGlobalLayer(layer_idx);
std::vector<MatPtr> kv_ptrs =
qbatch.KV(query_idx).cache->GetPointers(
layer_idx, kv_head, kv_heads, start_pos, is_global_layer);
size_t tile_offset = 0;
if (!is_global_layer) {
tile_offset = start_pos / KVCache::kTileSize;
}
while (current_token_idx < num_tokens) {
const size_t pos = start_pos + current_token_idx;
const size_t pos_mod = activations.div_seq_len.Remainder(pos);
const size_t tile_idx = tile_size_divisor.Divide(pos_mod);
const size_t relative_tile_idx = tile_idx - tile_offset;
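// The cache may return several non-contiguous regions (e.g. due to
// wraparound in local layers); find the region containing this tile and the
// tile's row index within that region.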
KV_T* tile_ptr;
int kv_ptr_idx = 0;
size_t absolute_rows = 0;
while (absolute_rows + kv_ptrs[kv_ptr_idx].Rows() <=
relative_tile_idx) {
absolute_rows += kv_ptrs[kv_ptr_idx].Rows();
kv_ptr_idx++;
}
tile_ptr = HWY_RCAST_ALIGNED(
KV_T*,
kv_ptrs[kv_ptr_idx].RowBytes(relative_tile_idx - absolute_rows));
PackedSpan<KV_T> tile_packed_span{tile_ptr,
2 * qkv_dim * KVCache::kTileSize};
DecompressAndZeroPad(df, tile_packed_span, 0, k_tile_vec,
qkv_dim * KVCache::kTileSize);
DecompressAndZeroPad(df, tile_packed_span,
qkv_dim * KVCache::kTileSize, v_tile_vec,
qkv_dim * KVCache::kTileSize);
size_t token_in_tile_idx = current_token_idx;
while (token_in_tile_idx < num_tokens) {
const size_t current_pos =
qbatch.Pos(query_idx) + token_in_tile_idx;
const size_t current_pos_mod =
activations.div_seq_len.Remainder(current_pos);
if (tile_size_divisor.Divide(current_pos_mod) != tile_idx) {
break; // Moved to next tile
}
const float* kv_row =
kv_out_data +
(token_in_tile_idx * qbatch.Size() + query_idx) * kv_out_cols;
const float* k_ptr = kv_row + kv_head * 2 * qkv_dim;
const float* v_ptr = kv_row + kv_head * 2 * qkv_dim + qkv_dim;
hwy::CopyBytes(k_ptr, k_f32, qkv_dim * sizeof(float));
if (layer.key_norm_scale.HasPtr()) {
CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, k_f32,
qkv_dim, env.ctx, worker);
});
}
PositionalEncodingQK(
k_f32, layer_idx, activations, env.ctx, worker,
current_pos,
/*mul=*/1.0f);
const size_t in_tile_idx = current_pos_mod % KVCache::kTileSize;
if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
const int in_tile_idx_mod_2 = in_tile_idx % 2;
for (int dim = 0; dim < qkv_dim; dim += 2) {
const int dim_mod_2 = dim % 2;
// Pack k's in pairs in preparation for BF16 dot product.
// See flash_attention.cc
// QDotKTilexUpTo4TransposedKDoubleWidthBF16
k_tile_vec[(dim - dim_mod_2) * KVCache::kTileSize +
in_tile_idx * 2] = k_f32[dim];
k_tile_vec[(dim - dim_mod_2) * KVCache::kTileSize +
in_tile_idx * 2 + 1] = k_f32[dim + 1];
// Pack v's in pairs
v_tile_vec[(in_tile_idx - in_tile_idx_mod_2) * qkv_dim +
dim * 2 + in_tile_idx_mod_2] = v_ptr[dim];
v_tile_vec[(in_tile_idx - in_tile_idx_mod_2) * qkv_dim +
(dim + 1) * 2 + in_tile_idx_mod_2] = v_ptr[dim + 1];
}
} else {
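// Default layout: K is stored transposed (one row per dimension, kTileSize
// entries wide); V rows are compressed directly into the tile after the K
// block.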
for (int i = 0; i < qkv_dim; ++i) {
k_tile_vec[i * KVCache::kTileSize + in_tile_idx] = k_f32[i];
}
Compress(v_ptr, qkv_dim, tls, tile_packed_span,
qkv_dim * (KVCache::kTileSize + in_tile_idx));
}
token_in_tile_idx++;
}
Compress(k_tile_vec, qkv_dim * KVCache::kTileSize, tls,
tile_packed_span, 0);
if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
Compress(v_tile_vec, qkv_dim * KVCache::kTileSize, tls,
tile_packed_span, qkv_dim * KVCache::kTileSize);
}
current_token_idx = token_in_tile_idx;
}
});
}
// TODO: optimize with gathers.
// This format might change in the future, when the kernel is updated to
// support more than 8 queries.
// Input: (num_queries, qkv_dim)
// Output: (qkv_dim, num_queries)
void TransposeQ(const MatPtrT<float>& queries,
hwy::Span<float> transposed_queries_span) {
const size_t qkv_dim = queries.Cols();
const size_t num_queries = queries.Rows();
HWY_ASSERT(transposed_queries_span.size() == num_queries * qkv_dim);
for (size_t i = 0; i < qkv_dim; i++) {
for (size_t j = 0; j < num_queries; ++j) {
transposed_queries_span[i * num_queries + j] = queries.Row(j)[i];
}
}
}
// Transposes queries.
// queries: vector of pointers to subsequent queries (allows for arbitrary
// strides).
// qkv_dim: dimension of each query.
// transposed_queries: output span of contiguous memory with shape
// (qkv_dim, queries.size()).
void TransposeStridedQueries(
hwy::Span<float*> queries, int qkv_dim,
hwy::Span<float> transposed_queries) {
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
const DF df;
using VF = hn::Vec<DF>;
using DI = hn::ScalableTag<int32_t>;
const DI di;
using VI = hn::Vec<DI>;
const size_t lanes = hn::Lanes(df);
const size_t num_queries = queries.size();
const size_t num_queries_rounded_up = hwy::RoundUpTo(num_queries, lanes);
std::vector<int32_t, hwy::AlignedAllocator<int32_t>> query_offsets(
num_queries_rounded_up);
for (size_t i = 0; i < num_queries; ++i) {
query_offsets[i] = queries[i] - queries[0];
}
for (size_t i = num_queries; i < num_queries_rounded_up; ++i) {
// Repeat the last offset so the gather doesn't read out of bounds.
query_offsets[i] = query_offsets[num_queries - 1];
}
for (size_t i = 0; i < qkv_dim; i++) {
size_t j = 0;
if (num_queries >= lanes) {
for (; j <= num_queries-lanes; j += lanes) {
const VI offsets = hn::LoadU(di, query_offsets.data() + j);
VF x = hn::GatherIndex(df, queries[0] + i, offsets);
hn::StoreU(x, df, transposed_queries.data() + i * num_queries + j);
}
}
if (j < num_queries) {
const VI offsets = hn::LoadU(di, query_offsets.data() + j);
VF x = hn::GatherIndex(df, queries[0] + i, offsets);
hn::StoreN(x, df, transposed_queries.data() + i * num_queries + j,
num_queries - j);
}
}
}
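// Transposes queries in groups of up to 4, so each group is laid out as
// (qkv_dim, group_size) in one contiguous buffer, presumably matching the
// up-to-4-query tiles used by the flash attention kernel. Returns the backing
// storage together with a pointer to the start of each group.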
std::pair<AlignedFloatVector, std::vector<float*>> TransposeQueriesToGroupsOf4(
hwy::Span<float*> queries_ptrs, int qkv_dim) {
int num_queries = queries_ptrs.size();
int num_groups = hwy::DivCeil(num_queries, 4);
AlignedFloatVector transposed_queries(num_groups * 4 * qkv_dim);
std::vector<float*> transposed_queries_ptrs;
for (int group_idx = 0; group_idx < num_groups; ++group_idx){
int group_size = std::min(4, num_queries - group_idx * 4);
transposed_queries_ptrs.push_back(transposed_queries.data() +
group_idx * qkv_dim * 4);
TransposeStridedQueries(
hwy::Span<float*>(queries_ptrs.data() + group_idx * 4,
group_size),
qkv_dim,
hwy::Span<float>(transposed_queries_ptrs.back(), qkv_dim * group_size));
}
return std::make_pair(std::move(transposed_queries),
std::move(transposed_queries_ptrs));
}
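// Converts already-transposed query groups to BF16, interleaving adjacent
// dimension pairs per query to match the packed-K layout expected by the
// BF16 dot-product kernel (see ComputeQKVTransposedTile above).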
std::pair<AlignedBF16Vector, std::vector<BF16*>>
TransposeTransposedQueriesAndPackIntoBF16(hwy::Span<float*> queries_ptrs,
int qkv_dim, int num_queries) {
constexpr int kMaxGroupSize = 4;
int num_groups = queries_ptrs.size();
AlignedBF16Vector transposed_queries(num_groups * kMaxGroupSize * qkv_dim);
std::vector<BF16*> transposed_queries_ptrs;
transposed_queries_ptrs.reserve(num_groups);
for (int group_idx = 0; group_idx < num_groups; ++group_idx) {
int group_size =
std::min(kMaxGroupSize, num_queries - group_idx * kMaxGroupSize);
transposed_queries_ptrs.push_back(transposed_queries.data() +
group_idx * qkv_dim * kMaxGroupSize);
for (int dim_idx = 0; dim_idx < qkv_dim; dim_idx += 2) {
for (int query_idx = 0; query_idx < group_size; ++query_idx) {
transposed_queries_ptrs.back()[dim_idx * group_size + query_idx * 2] =
hwy::ConvertScalarTo<BF16>(
queries_ptrs[group_idx][dim_idx * group_size + query_idx]);
transposed_queries_ptrs
.back()[dim_idx * group_size + query_idx * 2 + 1] =
hwy::ConvertScalarTo<BF16>(
queries_ptrs[group_idx]
[(dim_idx + 1) * group_size + query_idx]);
}
}
}
return std::make_pair(std::move(transposed_queries),
std::move(transposed_queries_ptrs));
}
template <typename T>
static HWY_INLINE void MaybeResizeMatStorage(MatStorageT<T>& mat_storage,
int rows, int cols,
const char* name,
const Allocator& allocator) {
if (mat_storage.Rows() != rows || mat_storage.Cols() != cols) {
mat_storage = MatStorageT<T>(name, Extents2D(rows, cols), allocator,
MatPadding::kOdd);
}
}
// clang-format off
// Schedules TiledFlashAttention for all heads, tokens and the batch.
// Returns partial results in the same order as queries in `activations.q`.
// Might not work yet for prefix LM.
// To help understand how to use this function, the parameters are used as
// follows:
//
// attention_impl - determines which attention kernel to use.
// num_query_tokens - number of tokens/timesteps processed in a single batch;
// it determines how many queries are evaluated against the KVs.
// layer_idx - layer index.
// layer - used to get kv_heads, heads, qkv_dim.
// activations - reads: activations.q queries, att_cap, IsGlobalLayer.
// qbatch - KV cache, Pos / PrefixEnd.
// ctx - threading context.
// clang-format on
void LocalAttentionForAllHeadsTokensAndBatch(
AttentionImpl attention_impl, const size_t num_query_tokens,
const size_t layer_idx, const LayerWeightsPtrs& layer,
AttentionActivationsPtrs& activations, QBatch& qbatch,
ThreadingContext& ctx) {
const size_t heads_per_kv_head =
layer.layer_config.heads / layer.layer_config.kv_heads;
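// Split each (qbatch, kv_head) task into task_multiplier sub-tasks over the
// context so that there are at least 2x as many sub-tasks as workers.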
int core_count = ctx.pools.MaxWorkers();
int task_multiplier = 1;
while (qbatch.Size() * layer.layer_config.kv_heads * task_multiplier <
core_count * 2) {
task_multiplier++;
}
// Find the smallest context we need to attend to, to avoid unnecessary
// overhead when sub-splitting doesn't make sense. This check overestimates
// context sizes because it ignores [local] layer sizes and explicit
// qbatch.PrefixEnd settings.
size_t min_pos = qbatch.Pos(0);
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
min_pos = std::min(min_pos, qbatch.Pos(qi));
}
if (min_pos / task_multiplier < num_query_tokens) {
// When min_pos / task_multiplier < num_tokens, making sure we don't
// over-count tokens or read out of bounds would require considerably more
// involved logic. There is also not much point in splitting the work into
// more tasks when the amount of work is small.
task_multiplier = 1;
}
[[maybe_unused]] int num_tasks = qbatch.Size() * layer.layer_config.kv_heads;
[[maybe_unused]] int num_sub_tasks =
qbatch.Size() * layer.layer_config.kv_heads * task_multiplier;
HWY_DASSERT_M(activations.q.Rows() == num_query_tokens * qbatch.Size(),
"qbatch size mismatch");
int qkv_dim = layer.layer_config.qkv_dim;
// The sizes of these three vectors must stay in sync.
if (num_sub_tasks > activations.sub_task_att_out->size()) {
activations.sub_task_att_out->resize(num_sub_tasks);
activations.sub_task_exp_denominator_sums->resize(num_sub_tasks);
activations.sub_task_max_logits->resize(num_sub_tasks);
}
std::vector<int> skip_sub_task(num_sub_tasks, 0);
// This loop parallelizes over qbatch, kv_head and substrings of context
// tokens. Each parallel invocation handles all query tokens of the given
// qbatch.
ParallelFor(
Parallelism::kHierarchical, num_sub_tasks, ctx,
/*cluster_idx=*/0, Callers::kFlashAttention,
[&](size_t task_idx, size_t worker) HWY_ATTR {
size_t main_task_idx = task_idx / task_multiplier;
size_t sub_task_idx = task_idx % task_multiplier;
size_t current_qbatch_idx =
main_task_idx / layer.layer_config.kv_heads;
size_t kv_head_idx = main_task_idx % layer.layer_config.kv_heads;
// First and last context token we will attend to.
size_t global_start_context_pos = StartPos(
qbatch.Pos(current_qbatch_idx), activations.config, layer_idx);
// Keep in mind this is an overestimate because some timesteps might not
// need all tokens due to the causal mask.
// We use it to determine how to divide work between sub-tasks
// and to make sure PrefixEnd is taken into account.
size_t start_context_pos = global_start_context_pos;
size_t last_context_pos =
qbatch.Pos(current_qbatch_idx) + num_query_tokens - 1;
// In some models, context is limited to some prefix - make sure we take
// that into account.
const size_t prefix_end = qbatch.PrefixEnd(current_qbatch_idx);
if (prefix_end > 0 && prefix_end - 1 > last_context_pos) {
last_context_pos = prefix_end - 1;
}
size_t total_num_context_tokens =
last_context_pos - start_context_pos + 1;
size_t context_tokens_per_sub_task =
hwy::DivCeil(total_num_context_tokens, task_multiplier);
// Restrict tokens to attend to the substring of context tokens that
// this subtask is responsible for.
start_context_pos =
start_context_pos + context_tokens_per_sub_task * sub_task_idx;
if (start_context_pos > last_context_pos) {
skip_sub_task[task_idx] = 1;
return;
}
last_context_pos =
std::min(last_context_pos,
start_context_pos + context_tokens_per_sub_task - 1);
// Pre-initialize memory [to avoid racy resizes later].
int num_queries = num_query_tokens * heads_per_kv_head;
std::vector<float*> queries_ptrs;
queries_ptrs.reserve(num_queries);
for (int token_idx = 0; token_idx < num_query_tokens; ++token_idx) {
for (int q_head_idx = 0; q_head_idx < heads_per_kv_head;
++q_head_idx) {
queries_ptrs.push_back(
activations.q.Row(token_idx * qbatch.Size() +
current_qbatch_idx) +
(kv_head_idx * heads_per_kv_head + q_head_idx) * qkv_dim);
}
}
hwy::Span<float*> queries_ptrs_span(queries_ptrs.data(),
queries_ptrs.size());
auto [transposed_queries, transposed_queries_ptrs] =
TransposeQueriesToGroupsOf4(queries_ptrs_span, qkv_dim);
MatStorageT<float>& att_out =
activations.sub_task_att_out->at(task_idx);
AlignedFloatVector& exp_denominator_sums =
activations.sub_task_exp_denominator_sums->at(task_idx);
AlignedFloatVector& max_logits =
activations.sub_task_max_logits->at(task_idx);
MaybeResizeMatStorage(att_out, num_queries, qkv_dim, "att_out",
ctx.allocator);
for (int i = 0; i < num_queries; ++i) {
hwy::ZeroBytes(att_out.Row(i),
att_out.Cols() * sizeof(decltype(att_out.Row(i)[0])));
}
int num_queries_rounded_to_8 = hwy::RoundUpTo(num_queries, 8);
exp_denominator_sums.resize(num_queries_rounded_to_8);
max_logits.resize(num_queries_rounded_to_8);
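// Initialize max logits to half of the lowest representable float,
// presumably so that later subtractions of two logits cannot overflow.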
for (int i = 0; i < num_queries_rounded_to_8; ++i) {
exp_denominator_sums[i] = 0.0f;
max_logits[i] = -std::numeric_limits<float>::max() / 2.0f;
}
// Get pointers to the KVCache tiles, starting at global_start_context_pos.
// Returns multiple matrices for non-contiguous memory, for example as a
// result of the wraparound in local layers.
std::vector<MatPtr> kv_ptrs =
qbatch.KV(current_qbatch_idx)
.cache->GetPointers(
layer_idx, kv_head_idx, layer.layer_config.kv_heads,
global_start_context_pos,
activations.config.IsGlobalLayer(layer_idx));
std::vector<size_t, hwy::AlignedAllocator<size_t>> start_pos_per_query;
std::vector<size_t, hwy::AlignedAllocator<size_t>> last_pos_per_query;
start_pos_per_query.reserve(num_queries);
last_pos_per_query.reserve(num_queries);
// Position of the first token in the first tile whose pointer was
// returned above. Allows for handling of token positions relative to
// the KV tiles returned above.
size_t rounded_down_global_start_pos =
hwy::RoundDownTo(global_start_context_pos, KVCache::kTileSize);
for (int token_idx = 0; token_idx < num_query_tokens; ++token_idx) {
int64_t global_query_pos =
qbatch.Pos(current_qbatch_idx) + token_idx;
// Intersect the context this specific query token attends to with the
// context tokens of the current sub-task.
int64_t query_last_context_pos = std::min(
static_cast<int64_t>(last_context_pos), global_query_pos);
// The max keeps the value from going negative, which is also why we use
// int64_t rather than size_t here.
int64_t query_start_context_pos = std::max(
global_query_pos -
static_cast<int64_t>(
activations.config.attention_window_sizes[layer_idx]) +
1,
static_cast<int64_t>(start_context_pos));
// Turn token position into KV-tile relative token positions.
query_last_context_pos -= rounded_down_global_start_pos;
query_start_context_pos -= rounded_down_global_start_pos;
for (int q_head_idx = 0; q_head_idx < heads_per_kv_head;
++q_head_idx) {
start_pos_per_query.push_back(query_start_context_pos);
last_pos_per_query.push_back(query_last_context_pos);
}
}
if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
// pack transposed queries into BF16
hwy::Span<float*> queries_span(transposed_queries_ptrs.data(),
transposed_queries_ptrs.size());
auto [_, transposed_queries_ptrs_bf16] =
TransposeTransposedQueriesAndPackIntoBF16(queries_span, qkv_dim,
num_queries);
hwy::Span<const BF16*> queries_span_bf16(
const_cast<const BF16**>(transposed_queries_ptrs_bf16.data()),
transposed_queries_ptrs_bf16.size());
DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
kv_ptrs, num_queries, queries_span_bf16,
hwy::Span<const size_t>(start_pos_per_query),
hwy::Span<const size_t>(last_pos_per_query),
activations.config.att_cap, att_out, exp_denominator_sums.data(),
max_logits.data());
} else {
DispatchTileFlashAttentionReturnExpSumsAndMaxLogits(
kv_ptrs, num_queries,
hwy::Span<const float*>(
const_cast<const float**>(transposed_queries_ptrs.data()),
transposed_queries_ptrs.size()),
hwy::Span<const size_t>(start_pos_per_query),
hwy::Span<const size_t>(last_pos_per_query),
activations.config.att_cap, att_out, exp_denominator_sums.data(),
max_logits.data());
}
});
// This loop takes the results from separate sub-tasks (each covering a
// subsequence of the KV cache) and merges them into a single att_out over
// the whole KV sequence.
ParallelFor(
Parallelism::kFlat, num_tasks, ctx,
/*cluster_idx=*/0, Callers::kFlashAttention,
[&](size_t main_task_idx, size_t worker) HWY_ATTR {
size_t current_qbatch_idx = main_task_idx / layer.layer_config.kv_heads;
size_t kv_head_idx = main_task_idx % layer.layer_config.kv_heads;
for (int token_idx = 0; token_idx < num_query_tokens; ++token_idx) {
for (int head_in_group_idx = 0; head_in_group_idx < heads_per_kv_head;
++head_in_group_idx) {
const size_t batch_index =
current_qbatch_idx * num_query_tokens + token_idx;
const size_t q_head_idx =
kv_head_idx * heads_per_kv_head + head_in_group_idx;
const size_t att_out_row_idx =
token_idx * heads_per_kv_head + head_in_group_idx;
const size_t activations_att_out_start_idx = q_head_idx * qkv_dim;
auto& att_out_0 = activations.sub_task_att_out->at(
main_task_idx * task_multiplier + 0);
auto& exp_denominator_sums_0 =
activations.sub_task_exp_denominator_sums->at(
main_task_idx * task_multiplier + 0);
auto& max_logits_0 = activations.sub_task_max_logits->at(
main_task_idx * task_multiplier + 0);
hwy::CopyBytes(att_out_0.Row(att_out_row_idx),
activations.att_out.Row(batch_index) +
activations_att_out_start_idx,
qkv_dim * sizeof(float));
activations.softmax_d.Row(batch_index)[q_head_idx] =
exp_denominator_sums_0[token_idx * heads_per_kv_head +
head_in_group_idx];
activations.softmax_max.Row(batch_index)[q_head_idx] =
max_logits_0[token_idx * heads_per_kv_head + head_in_group_idx];
for (int sub_task_idx = 1; sub_task_idx < task_multiplier;
++sub_task_idx) {
int task_idx = main_task_idx * task_multiplier + sub_task_idx;
if (skip_sub_task[task_idx] == 1) {
continue;
}
auto& att_out = activations.sub_task_att_out->at(task_idx);
auto& exp_denominator_sums =
activations.sub_task_exp_denominator_sums->at(task_idx);
auto& max_logits = activations.sub_task_max_logits->at(task_idx);
MergeOnlineSoftmax(
att_out.Row(att_out_row_idx),
max_logits[token_idx * heads_per_kv_head + head_in_group_idx],
exp_denominator_sums[token_idx * heads_per_kv_head +
head_in_group_idx],
qkv_dim,
activations.att_out.Row(batch_index) +
activations_att_out_start_idx,
activations.softmax_max.Row(batch_index)[q_head_idx],
activations.softmax_d.Row(batch_index)[q_head_idx]);
}
}
}
});
}
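// Top-level entry point: computes Q and writes K/V into the tiled KV cache,
// applies RMSNorm and positional encoding to Q, runs the tiled/flash
// attention over all heads, tokens and the batch, then sums the heads.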
void TiledAttention(AttentionImpl attention_impl, size_t num_tokens,
const size_t layer_idx, const LayerWeightsPtrs& layer,
AttentionActivationsPtrs& activations, QBatch& qbatch,
MatMulEnv& env, int flags) {
static const auto zone = env.ctx.profiler.AddZone(
"Gen.TiledAttention", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
const LayerConfig& layer_config = layer.layer_config;
HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,
"query heads must be a multiple of key-value heads");
(void)layer_config; // only used in HWY_DASSERT
if (qbatch.KV(0).cache->compact_kv_cache_ptr.GetType() == Type::kBF16) {
ComputeQKVTransposedTile<BF16>(num_tokens, layer_idx, layer, attention_impl,
activations, qbatch, flags, env);
} else {
ComputeQKVTransposedTile<KV_t>(num_tokens, layer_idx, layer, attention_impl,
activations, qbatch, flags, env);
}
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q,
layer.query_norm_scale, layer_idx, activations,
env.ctx);
LocalAttentionForAllHeadsTokensAndBatch(attention_impl, num_tokens, layer_idx,
layer, activations, qbatch, env.ctx);
SumHeads(layer, activations, env);
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();