diff --git a/ggml/src/ggml-cuda/where-id.cu b/ggml/src/ggml-cuda/where-id.cu index 993873462b..2d9130035a 100644 --- a/ggml/src/ggml-cuda/where-id.cu +++ b/ggml/src/ggml-cuda/where-id.cu @@ -2,22 +2,23 @@ static __global__ void where_id_kernel( const float * src0, const int32_t * src1, float * dst, - int64_t ne10, int64_t ne11, int64_t ne12, - size_t nb1, size_t nb2, - size_t nb01, size_t nb02, - size_t nb11, size_t nb12 + int64_t ne10, int64_t ne11, int64_t ne12, int64_t ne13, + size_t nb1, size_t nb2, size_t nb3, + size_t nb01, size_t nb02, size_t nb03, + size_t nb11, size_t nb12, size_t nb13 ) { - const int64_t total_blocks = ne11 * ne12; + const int64_t total_blocks = ne11 * ne12 * ne13; for (int64_t block_idx = blockIdx.x; block_idx < total_blocks; block_idx += gridDim.x) { const int64_t i1 = block_idx % ne11; - const int64_t i2 = block_idx / ne11; + const int64_t i2 = (block_idx / ne11) % ne12; + const int64_t i3 = block_idx / (ne11 * ne12); - float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2); - const float * src0_row = (const float *)((const char *)src0 + i1*nb01 + i2*nb02); - const int * src1_row = (const int *)((const char *)src1 + i1*nb11 + i2*nb12); + float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2 + i3*nb3); + const float * src0_row = (const float *)((const char *)src0 + i1*nb01 + i2*nb02 + i3*nb03); + const int * src1_row = (const int *)((const char *)src1 + i1*nb11 + i2*nb12 + i3*nb13); for (int64_t i0 = threadIdx.x; i0 < ne10; i0 += blockDim.x) { const int32_t id = src1_row[i0]; @@ -64,14 +65,14 @@ void ggml_cuda_op_where_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { int threads = std::min((int) ne20, 768); // ids - int64_t total_blocks = ne21 * ne22; + int64_t total_blocks = ne21 * ne22 * ne23; int blocks = (int) std::min((int64_t) 65535, total_blocks); where_id_kernel<<<blocks, threads, 0, ctx.stream()>>>( src0_d, src2_d, dst_d, - ne20, ne21, ne22, - nb1, nb2, - nb01, nb02, - nb21, nb22 + ne20, ne21, ne22, ne23, + nb1, nb2, nb3, + nb01, nb02, nb03, + 
nb21, nb22, nb23 ); }