From d6d24487c213a8742325621cf5371b699f83b9eb Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 10 Nov 2025 16:52:46 -0500 Subject: [PATCH] fixed a bug of not bounds-checking the batch dimension --- ggml/src/ggml-cuda/conv3d-implicit.cuh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cuh b/ggml/src/ggml-cuda/conv3d-implicit.cuh index d7d8ef1086..c226351388 100644 --- a/ggml/src/ggml-cuda/conv3d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv3d-implicit.cuh @@ -180,7 +180,8 @@ __device__ __forceinline__ void tileMemcpySwizzleA( unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); - if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c && kidx < end_k){ + if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && + n < param.n && curC < param.c && kidx < end_k){ int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC; dst_float4[dst_index] = reinterpret_cast(&src[n * inNOffset + inOffsetTmp])[0]; } else{ @@ -249,7 +250,7 @@ __device__ __forceinline__ void tileMemcpyLoadA( const int curW = posw_ori + curIdx.w * param.dilation0; // input w const int curC = curIdx.x; if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d - && curC < param.c && kidx < end_k){ + && n < param.n && curC < param.c && kidx < end_k){ int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC; dst_reg[i] = reinterpret_cast(&src[n * inNOffset + inOffsetTmp])[0]; } else{