mirror of https://github.com/google/gemma.cpp.git
Reduced parallelism for TransposeQ, making each thread read and write within its own cache lines
PiperOrigin-RevId: 814241032
This commit is contained in:
parent
14244664c8
commit
684a0444e9
|
|
@ -61,22 +61,30 @@ static constexpr size_t kNFx8HTileSize = 8;
|
|||
static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
|
||||
const size_t qbatch_size, ThreadingContext& ctx) {
|
||||
static const auto zone = ctx.profiler.AddZone("Gen.Attention.TransposeQ");
|
||||
// Group floats by the number of floats in a cache line.
|
||||
const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float);
|
||||
const size_t num_heads = q.Cols() / q_t.Rows();
|
||||
const size_t batch_size = q.Rows() / qbatch_size;
|
||||
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
|
||||
PROFILER_ZONE3(ctx.profiler, worker, zone);
|
||||
float* HWY_RESTRICT qt_row = q_t.Row(task);
|
||||
for (size_t qi = 0; qi < qbatch_size; ++qi)
|
||||
for (size_t lane = 0; lane < kNF; ++lane) {
|
||||
size_t q_row = task * kNF + lane;
|
||||
if (q_row >= q_t.Rows()) break;
|
||||
float* HWY_RESTRICT qt_row = q_t.Row(q_row);
|
||||
for (size_t qi = 0; qi < qbatch_size; ++qi) {
|
||||
for (size_t h = 0; h < num_heads; ++h) {
|
||||
for (size_t b = 0; b < batch_size; ++b) {
|
||||
qt_row[(qi * num_heads + h) * batch_size + b] =
|
||||
q.Row(b * qbatch_size + qi)[h * q_t.Rows() + task];
|
||||
q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
{
|
||||
// Better than kFlat.
|
||||
ParallelFor(ParallelismStrategy::kHierarchical, q_t.Rows(), ctx,
|
||||
size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF);
|
||||
ParallelFor(ParallelismStrategy::kHierarchical, num_tasks, ctx,
|
||||
/*cluster_idx=*/0, func);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue