mirror of https://github.com/google/gemma.cpp.git
Avoid transposing Q when it isn't needed
PiperOrigin-RevId: 814187984
parent fe5a39990e
commit 14244664c8
@@ -51,6 +51,8 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
 
+static constexpr size_t kNFx8HTileSize = 8;
+
 // Transposes q into q_t.
 // Both are 4D tensors stuffed into a 2-D MatPtrT.
 // q has shape [batch, qbatch][head, qkv_dim].
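The layout comments above describe 4-D tensors packed into 2-D matrices. A minimal index-mapping sketch, under the assumption that each bracketed group is packed row-major in the order written; the dimension constants and helper names here are hypothetical, for illustration only:

// q is [batch, qbatch][head, qkv_dim]: MatPtrT rows index (batch, qbatch),
// columns index (head, qkv_dim), so one row holds all heads of one query.
constexpr size_t kBatch = 4, kQBatch = 2, kHeads = 8, kQKVDim = 256;
size_t QRow(size_t batch, size_t qbatch) { return batch * kQBatch + qbatch; }
size_t QCol(size_t head, size_t dim) { return head * kQKVDim + dim; }

// q_T is [qkv_dim][qbatch, head, batch]: each column is one query vector,
// one element per row, so consecutive columns can share KV matrices.
size_t QTRow(size_t dim) { return dim; }
size_t QTCol(size_t qbatch, size_t head, size_t batch) {
  return (qbatch * kHeads + head) * kBatch + batch;
}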
@@ -191,7 +193,7 @@ void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
                     hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1,
                     VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6,
                     VF& sum7) {
-  constexpr size_t kHTileSize = 8;
+  constexpr size_t kHTileSize = kNFx8HTileSize;
   sum0 = hn::Zero(df);
   sum1 = hn::Zero(df);
   sum2 = hn::Zero(df);
@@ -268,7 +270,7 @@ void TileFlashAttention(
     hwy::Profiler& p, const size_t worker) {
   static const auto zone = p.AddZone("Gen.Attention.TileFlashAttention");
   PROFILER_ZONE3(p, worker, zone);
-  constexpr int kHTileSize = 8;
+  constexpr int kHTileSize = kNFx8HTileSize;
   using DF = hn::ScalableTag<float>;
   const DF df;
   using VF = hn::Vec<DF>;
@@ -365,13 +367,12 @@ void TileFlashAttention(
 
 // Returns an 4 Q rows by NF K tile of Q.K dot products, in single precision.
 // This is the result of 4 rows of Q against NF K timesteps, with positions
-// given by k_offsets[0..NF]. Q has been transposed so that the 4 rows are read
-// in consecutive elements, and other columns by adding q_stride.
+// given by k_offsets[0..NF].
 template <class DF, class VF = hn::Vec<DF>>
-void QDotKTilex4(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
-                 const MatPtrT<KV_t>& k, const int32_t* HWY_RESTRICT k_offsets,
-                 hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1,
-                 VF& sum2, VF& sum3) {
+void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
+                 const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<KV_t>& k,
+                 const int32_t* HWY_RESTRICT k_offsets, hwy::Profiler& p,
+                 const size_t worker, VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
   sum0 = hn::Zero(df);
   sum1 = hn::Zero(df);
   sum2 = hn::Zero(df);
@@ -383,15 +384,14 @@ void QDotKTilex4(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
   VI k_offsets_vec = hn::LoadU(di, k_offsets);
   for (size_t i = 0; i < k.Cols(); ++i) {
     VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec);
-    VF q_0 = hn::Set(df, q[0]);
+    VF q_0 = hn::Set(df, q[q_offsets[0] + i]);
     sum0 = hn::MulAdd(q_0, k_vec, sum0);
-    VF q_1 = hn::Set(df, q[1]);
+    VF q_1 = hn::Set(df, q[q_offsets[1] + i]);
     sum1 = hn::MulAdd(q_1, k_vec, sum1);
-    VF q_2 = hn::Set(df, q[2]);
+    VF q_2 = hn::Set(df, q[q_offsets[2] + i]);
     sum2 = hn::MulAdd(q_2, k_vec, sum2);
-    VF q_3 = hn::Set(df, q[3]);
+    VF q_3 = hn::Set(df, q[q_offsets[3] + i]);
     sum3 = hn::MulAdd(q_3, k_vec, sum3);
-    q += q_stride;
   }
 }
 
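In scalar terms, the rewritten loop computes, for each of the four Q rows r and each vector lane l, the dot product of Q row r with the K row starting at k_offsets[l]. A plain-C++ reference sketch (float K instead of KV_t, Highway types elided; this restates the diff and is not code from the file):

// sums[r * num_lanes + l] = dot(Q row r, K row at k_offsets[l]).
void QDotKTilex4Scalar(const float* q, const uint32_t* q_offsets,
                       const float* k_base, const int32_t* k_offsets,
                       size_t num_lanes, size_t cols, float* sums) {
  for (size_t r = 0; r < 4; ++r) {
    for (size_t l = 0; l < num_lanes; ++l) {
      float sum = 0.0f;
      for (size_t i = 0; i < cols; ++i) {
        // hn::Set broadcasts q[q_offsets[r] + i] to all lanes;
        // hn::GatherIndex fetches k_base[k_offsets[l] + i] per lane.
        sum += q[q_offsets[r] + i] * k_base[k_offsets[l] + i];
      }
      sums[r * num_lanes + l] = sum;
    }
  }
}

Reading each row directly from untransposed q via q_offsets replaces the old walk over a transposed Q column with q_stride, which is what lets the 4-row path skip TransposeQ entirely.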
@@ -416,10 +416,9 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
 // max_last_pos].
 void TileFlashAttention4(
     const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
-    const StridedView<float>& qT, const MatPtrT<KV_t>& k,
-    const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos,
-    const size_t min_last_pos, const size_t max_last_pos,
-    const MatPtrT<KV_t>& v, const size_t layer_idx,
+    const MatPtrT<KV_t>& k, const size_t start_pos,
+    const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
+    const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
     const LayerWeightsPtrs& layer, const AttentionActivations& activations,
     MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
     hwy::Profiler& p, const size_t worker) {
@@ -445,8 +444,6 @@ void TileFlashAttention4(
   float old_d1 = 0.0f;
   float old_d2 = 0.0f;
   float old_d3 = 0.0f;
-  const float* HWY_RESTRICT qT_row = qT.Row(0);
-  const size_t qT_stride = qT.Stride();
   size_t position = start_pos;
   while (position + kHTileSize - 1 <= min_last_pos) {
     int32_t k_offsets[kMaxNF];
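old_d0..old_d3 are per-row running softmax denominators; SingleFlashAttentionRowVector (see the hunk header above) also carries a running maximum old_max. As a scalar sketch of the standard streaming-softmax recurrence that this kind of helper implements (textbook flash attention, not code from this file; old_max is assumed to start at negative infinity and old_d at 0.0f, as in the declarations above):

#include <cmath>

// Folds one new logit x, with its V value v, into the running state of one
// Q row: running max, softmax denominator, and weighted V accumulator.
void StreamingSoftmaxStep(float x, float v, float& old_max, float& old_d,
                          float& acc) {
  const float new_max = std::fmax(old_max, x);
  const float scale = std::exp(old_max - new_max);  // rescales old state
  const float w = std::exp(x - new_max);            // weight of new logit
  old_d = old_d * scale + w;
  acc = acc * scale + w * v;
  old_max = new_max;
}

The attention output for the row is acc / old_d once all timesteps have been folded in.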
@@ -456,7 +453,8 @@ void TileFlashAttention4(
       k_offsets[i] = k.Row(v_pos[i]) - k.Row(0);
     }
     VF x0, x1, x2, x3;
-    QDotKTilex4(df, qT_row, qT_stride, k, k_offsets, p, worker, x0, x1, x2, x3);
+    QDotKTilex4(df, q.Row(0), q_offsets, k, k_offsets, p, worker, x0, x1, x2,
+                x3);
     if (activations.config.att_cap > 0.0f) {
       // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
       VF cap = hn::Set(df, activations.config.att_cap);
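The att_cap branch applies soft-capping to the whole Q.K tile. Taken directly from the comment in the diff, the per-element operation is tanh(x / cap) * cap; a scalar sketch:

#include <cmath>

// LogitsSoftCap: smoothly squash a logit into (-cap, cap).
float LogitsSoftCapScalar(float x, float cap) {
  return std::tanh(x / cap) * cap;
}

The vector code computes the same function lane-wise on x0..x3 before the softmax update.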
@@ -608,12 +606,29 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
           (kNF <= kMaxEqualK && total_tasks / kNF >= target_parallelism)
               ? kNF
               : std::min(kMinTileSize, kMaxEqualK);
-  // q has shape [batch, qbatch][head, qkv_dim].
-  // We transpose it to [qkv_dim][qbatch, head, batch] in order to make the
-  // maximum possible number of consecutive columns have the same KV matrices.
-  // Each thread will process a tile of NF columns of QT so the starting column
-  // index of QT is just the task index * kVTileSize.
-  TransposeQ(activations.q, activations.q_T, qbatch.Size(), ctx);
+  // Only transpose Q if we are using tiling.
+  if (kVTileSize == kNF) {
+    size_t max_last = 0, min_start = std::numeric_limits<size_t>::max();
+    for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+      size_t pos = qbatch.Pos(qi);
+      const size_t start = StartPos(pos, activations.config, layer_idx);
+      pos += num_tokens - 1;
+      const size_t end = qbatch.PrefixEnd(qi);
+      if (end > 0 && end - 1 > pos) {
+        pos = end - 1;
+      }
+      max_last = std::max(max_last, pos);
+      min_start = std::min(min_start, start);
+    }
+    if (max_last - min_start + 1 >= kNFx8HTileSize) {
+      // q has shape [batch, qbatch][head, qkv_dim].
+      // We transpose it to [qkv_dim][qbatch, head, batch] in order to make the
+      // maximum possible number of consecutive columns have the same KV
+      // matrices. Each thread will process a tile of NF columns of QT so the
+      // starting column index of QT is just the task index * kVTileSize.
+      TransposeQ(activations.q, activations.q_T, qbatch.Size(), ctx);
+    }
+  }
   const size_t num_thread_tasks = hwy::DivCeil(total_tasks, kVTileSize);
   const hwy::Divisor div_tokens(num_tokens);
   // All layers should have the same number of heads.
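The guard transposes Q only when some task can actually reach the 8-wide tiled loop. A standalone sketch of the same predicate, with the qbatch/StartPos/PrefixEnd plumbing replaced by hypothetical per-query [start, last] position ranges:

#include <algorithm>
#include <cstddef>
#include <limits>

constexpr size_t kNFx8HTileSize = 8;  // matches the constant added above

bool ShouldTransposeQ(const size_t* starts, const size_t* lasts, size_t n) {
  size_t max_last = 0;
  size_t min_start = std::numeric_limits<size_t>::max();
  for (size_t i = 0; i < n; ++i) {
    max_last = std::max(max_last, lasts[i]);
    min_start = std::min(min_start, starts[i]);
  }
  // Same predicate as the diff: the widest attended span must cover at
  // least one full K tile for TransposeQ to pay off.
  return max_last - min_start + 1 >= kNFx8HTileSize;
}

For example, a batch whose queries all attend to positions [100, 106] spans only 7 positions, less than one 8-wide tile, so the transpose is skipped.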
@@ -699,17 +714,23 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
             StridedView<float>(activations.q_T.Row(0) + first_task, kVTileSize,
                                activations.q_T.Stride());
         if (kVTileSize == kNF) {
+          // We can still use TileFlashAttention even if we didn't transpose Q
+          // above. The condition used for transposing Q above is more general
+          // and easier to compute than the condition used within
+          // TileFlashAttention that min_last_pos - start_positions[offset] <
+          // kNFx8HTileSize. In this case, qT is never used. Some tasks might
+          // use qT and some might not, which is why the more general condition
+          // is used above to catch all cases where qT will be used.
           TileFlashAttention(activations.q, q_offsets, qT, k,
                              start_positions[offset], last_pos, min_last_pos,
                              max_last_pos, v, layer_idx, layer, activations,
                              activations.att_out, out_offsets, ctx.profiler,
                              worker);
         } else if (kVTileSize == 4) {
-          TileFlashAttention4(activations.q, q_offsets, qT, k,
-                              start_positions[offset], last_pos, min_last_pos,
-                              max_last_pos, v, layer_idx, layer, activations,
-                              activations.att_out, out_offsets, ctx.profiler,
-                              worker);
+          TileFlashAttention4(
+              activations.q, q_offsets, k, start_positions[offset], last_pos,
+              min_last_pos, max_last_pos, v, layer_idx, layer, activations,
+              activations.att_out, out_offsets, ctx.profiler, worker);
         } else {
           HWY_UNREACHABLE;
         }