sycl: fix norm kernels (l2_norm, group_norm, rms_norm) by removing the asserts to support more cases (#19154)

Co-authored-by: Neo Zhang Jianyu <jianyu.zhang@intel.com>
This commit is contained in:
Neo Zhang 2026-01-29 09:20:22 +08:00 committed by GitHub
parent 50e8962f79
commit d4964a7c66
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 753 additions and 600 deletions

File diff suppressed because it is too large Load Diff

View File

@ -4606,14 +4606,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
#endif
case GGML_OP_NORM:
return true;
case GGML_OP_L2_NORM:
case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_RMS_NORM:
return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
return true;
case GGML_OP_RMS_NORM_BACK:
return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
return ggml_is_contiguous(op->src[0]);
case GGML_OP_SCALE:
return true;
case GGML_OP_CONT:

View File

@ -251,7 +251,6 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
const float eps, queue_ptr stream, int device) {
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) {
@ -334,7 +333,6 @@ static void group_norm_f32_sycl(const float* x, float* dst,
static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
@ -374,7 +372,6 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
const int nrows, const float eps,
queue_ptr stream, int device) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);