Apply suggestions from code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
Oliver Simons 2025-12-12 15:07:28 +01:00 committed by GitHub
parent 4d10b78e23
commit 07b809bbc0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 8 deletions

View File

@ -190,7 +190,7 @@ static void cumsum_cuda(
if (is_contiguous) { if (is_contiguous) {
use_cub = true; use_cub = true;
int64_t nrows = ne01 * ne02 * ne03; const int64_t nrows = ne01 * ne02 * ne03;
// TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released // TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released
// Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004 // Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004
if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) { if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) {

View File

@ -232,10 +232,12 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
// Compute thread-local max // Compute thread-local max
for (int col = col_start; col < p.ncols;) { for (int col = col_start; col < p.ncols;) {
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) { for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size; const int idx = col + i * step_size;
local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY; local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
} }
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) { for (int i = 0; i < n_elem_per_thread; i++) {
local_max = fmaxf(local_max, local_vals[i]); local_max = fmaxf(local_max, local_vals[i]);
} }
@ -263,10 +265,12 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
// Compute softmax dividends, accumulate divisor // Compute softmax dividends, accumulate divisor
float tmp_expf = 0.0f; float tmp_expf = 0.0f;
for (int col = col_start; col < p.ncols;) { for (int col = col_start; col < p.ncols;) {
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) { for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size; const int idx = col + i * step_size;
local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY; local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
} }
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) { for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size; const int idx = col + i * step_size;
if (idx < p.ncols) { if (idx < p.ncols) {
@ -297,10 +301,12 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
// Divide dividend by global sum + store data // Divide dividend by global sum + store data
for (int col = col_start; col < p.ncols;) { for (int col = col_start; col < p.ncols;) {
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) { for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size; const int idx = col + i * step_size;
local_vals[i] = idx < p.ncols ? dst[idx] : -INFINITY; local_vals[i] = idx < p.ncols ? dst[idx] : -INFINITY;
} }
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) { for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size; const int idx = col + i * step_size;
if (idx < p.ncols) { if (idx < p.ncols) {
@ -367,7 +373,7 @@ static void launch_soft_max_kernels(const float * x, const T * mask, const float
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p); soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
} }
static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x, __launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x,
float * __restrict__ dst, float * __restrict__ dst,
float * __restrict__ tmp_vals, float * __restrict__ tmp_vals,
const soft_max_params p) const soft_max_params p)
@ -408,7 +414,7 @@ static void soft_max_f32_cuda(const float * x,
if (nbytes_shared <= smpbo) { if (nbytes_shared <= smpbo) {
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared); launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
} else { } else {
// Parallelize across SMs for top-p/dist-smapling // Parallelize across SMs for top-p/dist-sampling
// The heuristic for parallelizing rows across SMs vs parallelizing single row & looping over all rows was done on the basis of a B6000 GPU and // The heuristic for parallelizing rows across SMs vs parallelizing single row & looping over all rows was done on the basis of a B6000 GPU and
// Can be adapted further for lower-SM-count GPUs, though keeping data in registers should be implemented first as that is the optimal solution. // Can be adapted further for lower-SM-count GPUs, though keeping data in registers should be implemented first as that is the optimal solution.
if (ggml_cuda_info().devices[id].supports_cooperative_launch && ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 && mask == nullptr && sinks == nullptr && params.scale == 1.0f && params.max_bias == 0.0f) { if (ggml_cuda_info().devices[id].supports_cooperative_launch && ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 && mask == nullptr && sinks == nullptr && params.scale == 1.0f && params.max_bias == 0.0f) {