diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 4af7f2ba1d..ca10582d23 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -5022,8 +5022,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_SOLVE_TRI: case GGML_OP_SCATTER: return true; - case GGML_OP_HADAMARD: - return (op->ne[0] == 64 || op->ne[0] == 128 || op->ne[0] == 256) && op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_HADAMARD: { + int nh = op->op_params[0]; + return (nh == 64 || nh == 128 || nh == 256) && op->ne[0] % nh == 0 && op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; + } default: return false; } diff --git a/ggml/src/ggml-cuda/hadamard.cu b/ggml/src/ggml-cuda/hadamard.cu index 5f34d2579d..f03866cb5a 100644 --- a/ggml/src/ggml-cuda/hadamard.cu +++ b/ggml/src/ggml-cuda/hadamard.cu @@ -30,7 +30,7 @@ static __global__ void hadamard_f32(const char * src, char * dst, int ne0, float scale = ksqrt2; #pragma unroll - for (int h = 2; h < nh; h <<= 2) { + for (int h = 2; h < nh; h <<= 1) { __syncthreads(); int ii = tid/h, jj = tid%h; int j = 2*h*ii+jj;