ggml : use WARP_SIZE/2 for argmax reduction offset
This commit is contained in:
parent
7b1db3d3b7
commit
c2f3f7a20e
|
|
@ -21,7 +21,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int offset = 16; offset > 0; offset >>= 1) {
|
for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
|
||||||
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
|
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
|
||||||
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
|
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
|
||||||
if (val > maxval) {
|
if (val > maxval) {
|
||||||
|
|
@ -50,7 +50,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
|
||||||
argmax = shared_argmax[lane_id];
|
argmax = shared_argmax[lane_id];
|
||||||
}
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int offset = 16; offset > 0; offset >>= 1) {
|
for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
|
||||||
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
|
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
|
||||||
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
|
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
|
||||||
if (val > maxval) {
|
if (val > maxval) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue