Reduced parallelism for TransposeQ, making each thread read and write within its own cache lines

PiperOrigin-RevId: 814241032
Ray Smith 2025-10-02 08:14:37 -07:00 committed by Copybara-Service
parent 14244664c8
commit 684a0444e9
1 changed file with 15 additions and 7 deletions


@@ -61,22 +61,30 @@ static constexpr size_t kNFx8HTileSize = 8;
 static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
                        const size_t qbatch_size, ThreadingContext& ctx) {
   static const auto zone = ctx.profiler.AddZone("Gen.Attention.TransposeQ");
+  // Group floats by the number of floats in a cache line.
+  const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float);
   const size_t num_heads = q.Cols() / q_t.Rows();
   const size_t batch_size = q.Rows() / qbatch_size;
   const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
     PROFILER_ZONE3(ctx.profiler, worker, zone);
-    float* HWY_RESTRICT qt_row = q_t.Row(task);
-    for (size_t qi = 0; qi < qbatch_size; ++qi)
-      for (size_t h = 0; h < num_heads; ++h) {
-        for (size_t b = 0; b < batch_size; ++b) {
-          qt_row[(qi * num_heads + h) * batch_size + b] =
-              q.Row(b * qbatch_size + qi)[h * q_t.Rows() + task];
+    for (size_t lane = 0; lane < kNF; ++lane) {
+      size_t q_row = task * kNF + lane;
+      if (q_row >= q_t.Rows()) break;
+      float* HWY_RESTRICT qt_row = q_t.Row(q_row);
+      for (size_t qi = 0; qi < qbatch_size; ++qi) {
+        for (size_t h = 0; h < num_heads; ++h) {
+          for (size_t b = 0; b < batch_size; ++b) {
+            qt_row[(qi * num_heads + h) * batch_size + b] =
+                q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row];
+          }
         }
       }
+    }
   };
   {
     // Better than kFlat.
-    ParallelFor(ParallelismStrategy::kHierarchical, q_t.Rows(), ctx,
+    size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF);
+    ParallelFor(ParallelismStrategy::kHierarchical, num_tasks, ctx,
                 /*cluster_idx=*/0, func);
   }
 }
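
A minimal standalone sketch of the partitioning arithmetic this change uses, assuming a 64-byte cache line and plain row-major std::vector buffers; kLineBytes, DivCeil, and TransposeBlock are illustrative stand-ins, not the gemma.cpp API (the commit itself uses ctx.cache_info.LineBytes(), hwy::DivCeil, and ParallelFor). Each task owns kNF consecutive destination rows, i.e. kNF consecutive source columns, so it reads a contiguous cache-line-sized span of every source row instead of interleaving single floats with other tasks.

// Sketch only: cache-line-sized task partitioning for a transpose.
#include <algorithm>
#include <cstddef>
#include <vector>

constexpr size_t kLineBytes = 64;                   // assumed cache-line size
constexpr size_t kNF = kLineBytes / sizeof(float);  // floats per cache line

size_t DivCeil(size_t a, size_t b) { return (a + b - 1) / b; }

// Transposes src (rows x cols, row-major) into dst (cols x rows, row-major).
// One call per task; the task touches only dst rows [task*kNF, end), i.e.
// a contiguous cache-line-sized span of columns in every src row.
void TransposeBlock(const std::vector<float>& src, size_t rows, size_t cols,
                    std::vector<float>& dst, size_t task) {
  const size_t begin = task * kNF;
  const size_t end = std::min(begin + kNF, cols);  // last task may be partial
  for (size_t r = begin; r < end; ++r) {           // dst row == src column
    for (size_t c = 0; c < rows; ++c) {
      dst[r * rows + c] = src[c * cols + r];
    }
  }
}

int main() {
  const size_t rows = 5, cols = 37;  // cols deliberately not a multiple of kNF
  std::vector<float> src(rows * cols), dst(cols * rows);
  for (size_t i = 0; i < src.size(); ++i) src[i] = static_cast<float>(i);

  const size_t num_tasks = DivCeil(cols, kNF);  // one task per line-sized group
  for (size_t task = 0; task < num_tasks; ++task) {  // stands in for ParallelFor
    TransposeBlock(src, rows, cols, dst, task);
  }
  return 0;
}

The rounded-up task count is why the committed code guards the final task with "if (q_row >= q_t.Rows()) break;": when the row count is not a multiple of kNF, the last task covers fewer than kNF rows.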