Apply suggestions from code review
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
commit 07b809bbc0
parent 4d10b78e23
@@ -190,7 +190,7 @@ static void cumsum_cuda(
 
     if (is_contiguous) {
         use_cub = true;
-        int64_t nrows = ne01 * ne02 * ne03;
+        const int64_t nrows = ne01 * ne02 * ne03;
         // TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released
         // Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004
         if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) {
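The heuristic above gates when the contiguous cumsum is routed to CUB. For orientation only, here is a minimal sketch of a single-row inclusive scan through cub::DeviceScan::InclusiveSum using CUB's standard two-call pattern; the helper name and the async allocation are illustrative, not the ggml code:

#include <cuda_runtime.h>
#include <cub/device/device_scan.cuh>

// Hypothetical helper, not from llama.cpp: inclusive prefix sum over one
// contiguous row of length n, i.e. the nrows == 1 fast path above.
static void cumsum_row_cub(const float * d_in, float * d_out, int n, cudaStream_t stream) {
    void * d_temp = nullptr;
    size_t temp_bytes = 0;
    // First call only computes the required temporary-storage size.
    cub::DeviceScan::InclusiveSum(d_temp, temp_bytes, d_in, d_out, n, stream);
    cudaMallocAsync(&d_temp, temp_bytes, stream);
    // Second call performs the scan itself.
    cub::DeviceScan::InclusiveSum(d_temp, temp_bytes, d_in, d_out, n, stream);
    cudaFreeAsync(d_temp, stream);
}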
@@ -232,10 +232,12 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
 
     // Compute thread-local max
     for (int col = col_start; col < p.ncols;) {
+#pragma unroll
         for (int i = 0; i < n_elem_per_thread; i++) {
             const int idx = col + i * step_size;
             local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
         }
+#pragma unroll
         for (int i = 0; i < n_elem_per_thread; i++) {
             local_max = fmaxf(local_max, local_vals[i]);
         }
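The two #pragma unroll hints added in this hunk annotate a deliberate structure: the first fully unrolled loop stages n_elem_per_thread values into registers, the second reduces over them, so the compiler can batch the loads separately from the arithmetic. A standalone sketch of the same pattern, with illustrative names and an assumed stride advance (the real kernel's loop increment is not visible in this hunk):

// Sketch only: per-thread strided max over a row, staging values in registers.
template <int n_elem_per_thread>
static __device__ float thread_local_max(const float * x, int ncols, int col_start, int step_size) {
    float local_max = -INFINITY;
    for (int col = col_start; col < ncols;) {
        float local_vals[n_elem_per_thread];
#pragma unroll
        for (int i = 0; i < n_elem_per_thread; i++) {   // staged loads
            const int idx = col + i * step_size;
            local_vals[i] = idx < ncols ? x[idx] : -INFINITY;
        }
#pragma unroll
        for (int i = 0; i < n_elem_per_thread; i++) {   // reduction over registers
            local_max = fmaxf(local_max, local_vals[i]);
        }
        col += n_elem_per_thread * step_size;           // assumed advance
    }
    return local_max;
}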
@@ -263,10 +265,12 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
     // Compute softmax dividends, accumulate divisor
     float tmp_expf = 0.0f;
     for (int col = col_start; col < p.ncols;) {
+#pragma unroll
         for (int i = 0; i < n_elem_per_thread; i++) {
             const int idx = col + i * step_size;
             local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
         }
+#pragma unroll
         for (int i = 0; i < n_elem_per_thread; i++) {
             const int idx = col + i * step_size;
             if (idx < p.ncols) {
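A note on the -INFINITY filler for out-of-range lanes: those lanes are skipped by the idx < p.ncols guard that follows, and even if one were exponentiated, expf(-INFINITY - m) is exactly zero for any finite row maximum m, so it could not perturb the accumulated divisor. A tiny standalone check of that identity:

#include <cassert>
#include <cmath>

int main() {
    const float row_max = 3.0f;                 // any finite row maximum
    assert(expf(-INFINITY - row_max) == 0.0f);  // a padded lane contributes nothing
    return 0;
}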
@@ -297,10 +301,12 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
 
     // Divide dividend by global sum + store data
     for (int col = col_start; col < p.ncols;) {
+#pragma unroll
         for (int i = 0; i < n_elem_per_thread; i++) {
             const int idx = col + i * step_size;
             local_vals[i] = idx < p.ncols ? dst[idx] : -INFINITY;
         }
+#pragma unroll
         for (int i = 0; i < n_elem_per_thread; i++) {
             const int idx = col + i * step_size;
             if (idx < p.ncols) {
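Taken together, this and the two preceding hunks are the three passes of the numerically stable softmax the kernel computes per row: row maximum, shifted exponentials with an accumulated divisor, then normalization. A scalar reference of the same math (illustrative, not the device code):

#include <cmath>

static void softmax_row_ref(const float * x, float * dst, int ncols) {
    float max_val = -INFINITY;
    for (int i = 0; i < ncols; i++) {        // pass 1: row maximum for stability
        max_val = fmaxf(max_val, x[i]);
    }
    float sum = 0.0f;
    for (int i = 0; i < ncols; i++) {        // pass 2: dividends, accumulate divisor
        dst[i] = expf(x[i] - max_val);
        sum += dst[i];
    }
    const float inv_sum = 1.0f / sum;
    for (int i = 0; i < ncols; i++) {        // pass 3: divide by the global sum
        dst[i] *= inv_sum;
    }
}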
@@ -367,7 +373,7 @@ static void launch_soft_max_kernels(const float * x, const T * mask, const float
         soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
     }
 
-static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x,
+__launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x,
                                                      float * __restrict__ dst,
                                                      float * __restrict__ tmp_vals,
                                                      const soft_max_params p)
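__launch_bounds__(8*WARP_SIZE, 1) promises the compiler that the kernel never runs with more than 8*WARP_SIZE = 256 threads per block and asks for only one resident block per SM, so the register allocator can spend more registers per thread instead of optimizing for occupancy. A minimal sketch of the attribute in isolation (placeholder kernel, not the softmax code):

#define WARP_SIZE 32

__launch_bounds__(8*WARP_SIZE, 1)
static __global__ void example_kernel(float * data, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        data[i] += 1.0f;
    }
}

// The launch configuration must respect the bound: at most 8*WARP_SIZE threads per block.
// example_kernel<<<num_blocks, 8*WARP_SIZE, 0, stream>>>(d_data, n);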
@@ -408,7 +414,7 @@ static void soft_max_f32_cuda(const float * x,
     if (nbytes_shared <= smpbo) {
         launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
     } else {
-        // Parallelize across SMs for top-p/dist-smapling
+        // Parallelize across SMs for top-p/dist-sampling
         // The heuristic for parallelizing rows across SMs vs parallelizing single row & looping over all rows was done on the basis of a B6000 GPU and
         // Can be adapted further for lower-SM-count GPUs, though keeping data in registers should be implemented first as that is the optimal solution.
         if (ggml_cuda_info().devices[id].supports_cooperative_launch && ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 && mask == nullptr && sinks == nullptr && params.scale == 1.0f && params.max_bias == 0.0f) {
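The supports_cooperative_launch check matters because the cols-parallelized path needs a grid-wide barrier between its reduction phases, which ordinary <<<>>> launches cannot provide. Below is a hedged sketch of the raw CUDA plumbing that a cached wrapper like ggml_cuda_info() sits on top of; the kernel and helper names are hypothetical:

#include <cuda_runtime.h>
#include <cooperative_groups.h>

__global__ void coop_example(float * data) {
    cooperative_groups::grid_group grid = cooperative_groups::this_grid();
    // ... phase 1: per-block partial reductions into global memory ...
    grid.sync();   // grid-wide barrier, only legal under a cooperative launch
    // ... phase 2: every block consumes the completed partial results ...
}

static bool try_coop_launch(float * d_data, dim3 grid_dim, dim3 block_dim, cudaStream_t stream) {
    int supported = 0;
    cudaDeviceGetAttribute(&supported, cudaDevAttrCooperativeLaunch, /*device=*/0);
    if (!supported) {
        return false;  // caller falls back to a non-cooperative kernel
    }
    void * args[] = { &d_data };
    return cudaLaunchCooperativeKernel((void *) coop_example, grid_dim, block_dim, args, 0, stream) == cudaSuccess;
}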