vulkan: Fix data races in coopmat1 mul_mat(_id) (#20084)
* vulkan: Fix data races in coopmat1 mul_mat(_id) Add barriers between coopmat store and regular loads. We sort of got away with this because it was the same subgroup accessing the values, but it's still a race and may not work. * switch to subgroup control barriers
This commit is contained in:
parent
a976ff081b
commit
cd18a50ea5
|
|
@ -377,6 +377,7 @@ void main() {
|
||||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||||
|
|
||||||
|
barrier();
|
||||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||||
const uint row_i = dc + cm_col * TN + col + store_c;
|
const uint row_i = dc + cm_col * TN + col + store_c;
|
||||||
if (row_i >= _ne1) break;
|
if (row_i >= _ne1) break;
|
||||||
|
|
@ -387,6 +388,7 @@ void main() {
|
||||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
barrier();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
|
|
@ -404,18 +406,22 @@ void main() {
|
||||||
// Full coopMat is within bounds, but stride_d is not aligned
|
// Full coopMat is within bounds, but stride_d is not aligned
|
||||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||||
|
|
||||||
|
controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
|
||||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||||
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||||
}
|
}
|
||||||
|
controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
|
||||||
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
|
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
|
||||||
// Partial coopMat is within bounds
|
// Partial coopMat is within bounds
|
||||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||||
|
|
||||||
|
controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
|
||||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||||
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
|
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
|
||||||
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue