cuda: optimize SOLVE_TRI using registers and FMAF (#17703)

* ggml-cuda: optimize solve_tri_f32_fast and fix stride handling

- Switch from using shared memory for the RHS/solution matrix to a register-based approach (x_low, x_high), reducing shared memory pressure and bank conflicts.
- Implement explicit `fmaf` instructions for the reduction loop.
- Update kernel arguments to pass strides in bytes rather than elements to align with standard ggml tensor arithmetic (casting to `char *` before addition).
- Remove unused `MAX_K_FAST` definition.

* Small cleanup

* Remove comments in solve_tri.cu

* Update ggml/src/ggml-cuda/solve_tri.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/solve_tri.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/solve_tri.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Use const for variables in solve_tri.cu

* Replace fmaf with more readable code

* remove last fmaf

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
wsbagnsv1 2025-12-08 10:41:08 +01:00 committed by GitHub
parent 79d61896d3
commit 5814b4dce1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 30 additions and 38 deletions

View File

@ -3,7 +3,6 @@
#include "solve_tri.cuh"
#define MAX_N_FAST 64
#define MAX_K_FAST 32
// ======================
// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
@ -48,65 +47,58 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
__shared__ float sA[MAX_N_FAST * MAX_N_FAST];
__shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
const int offset = threadIdx.x + threadIdx.y * blockDim.x;
#pragma unroll
for (int i = 0; i < n * n; i += k * WARP_SIZE) {
int i0 = i + offset;
const int i0 = i + offset;
if (i0 < n * n) {
sA[i0] = A_batch[i0];
}
}
const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
#pragma unroll
for (int i = 0; i < rows_per_warp; i++) {
const int i0 = lane + i * WARP_SIZE;
if (i0 < n) {
sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
}
}
__syncthreads();
float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
const int half = WARP_SIZE;
const int nrows_low = (n < half) ? n : half;
#pragma unroll
for (int row = 0; row < n; ++row) {
for (int row = 0; row < nrows_low; ++row) {
float sum = 0.0f;
{
int j = lane;
if (j < row) {
sum += sA[row * n + j] * sXt[col_idx * n + j];
}
if (lane < row) {
sum += sA[row * n + lane] * x_low;
}
if (row >= WARP_SIZE) {
int j = WARP_SIZE + lane;
if (j < row) {
sum += sA[row * n + j] * sXt[col_idx * n + j];
}
}
sum = warp_reduce_sum(sum);
if (lane == 0) {
const float b_val = sXt[col_idx * n + row];
const float a_diag = sA[row * n + row];
// no safeguards for division by zero because that indicates corrupt
// data anyway
sXt[col_idx * n + row] = (b_val - sum) / a_diag;
if (lane == row) {
x_low = (x_low - sum) / sA[row * n + row];
}
}
__syncthreads();
#pragma unroll
for (int row = half; row < n; ++row) {
float sum = sA[row * n + lane] * x_low;
const int j = half + lane;
if (j < row) {
sum += sA[row * n + j] * x_high;
}
sum = warp_reduce_sum(sum);
if (lane == row - half) {
x_high = (x_high - sum) / sA[row * n + row];
}
}
#pragma unroll
for (int i = 0; i < rows_per_warp; i++) {
const int i0 = lane + i * WARP_SIZE;
if (i0 < n) {
X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
for (int rr = 0; rr < 2; ++rr) {
const int row = rr * WARP_SIZE + lane;
if (row < n) {
const float val = (row < half) ? x_low : x_high;
X_batch[row * k + col_idx] = val;
}
}
}