diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index 2669a3f95c..324a5851c6 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -129,6 +129,8 @@ struct webgpu_context_struct {
     wgpu::ComputePipeline add_ip_pipeline[2];
     wgpu::ComputePipeline mul_pipeline[2];
     wgpu::ComputePipeline mul_ip_pipeline[2];
+    wgpu::ComputePipeline rms_norm_pipeline;
+    wgpu::ComputePipeline rms_norm_ip_pipeline;
 
     size_t memset_bytes_per_thread;
 
@@ -640,6 +642,56 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx,
     ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
 }
 
+static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+    bool in_place = ggml_webgpu_tensor_equal(src, dst);
+
+    uint32_t eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    std::vector<uint32_t> params = {
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+    };
+    if (!in_place) {
+        params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)));
+    }
+    params.push_back((uint32_t) (src->nb[1] / ggml_type_size(src->type)));
+    params.push_back((uint32_t) (src->nb[2] / ggml_type_size(src->type)));
+    params.push_back((uint32_t) (src->nb[3] / ggml_type_size(src->type)));
+    if (!in_place) {
+        params.push_back((uint32_t) (dst->nb[1] / ggml_type_size(dst->type)));
+        params.push_back((uint32_t) (dst->nb[2] / ggml_type_size(dst->type)));
+        params.push_back((uint32_t) (dst->nb[3] / ggml_type_size(dst->type)));
+    }
+    params.push_back((uint32_t) src->ne[0]);
+    params.push_back((uint32_t) src->ne[1]);
+    params.push_back((uint32_t) src->ne[2]);
+    params.push_back((uint32_t) src->ne[3]);
+    params.push_back(eps);  // epsilon, will be bitcast to float in shader
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+          .buffer  = ggml_webgpu_tensor_buf(src),
+          .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
+          .size    = ggml_webgpu_tensor_binding_size(ctx, src) }
+    };
+    if (!in_place) {
+        entries.push_back({ .binding = 1,
+                            .buffer  = ggml_webgpu_tensor_buf(dst),
+                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
+    }
+
+    wgpu::ComputePipeline pipeline;
+    if (in_place) {
+        pipeline = ctx->rms_norm_ip_pipeline;
+    } else {
+        pipeline = ctx->rms_norm_pipeline;
+    }
+    size_t   max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
+    uint32_t wg_x        = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
+    ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
+}
+
 // Returns true if node has enqueued work into the queue, false otherwise
 static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
     if (ggml_is_empty(node)) {
@@ -691,6 +743,11 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
                 }
                 break;
             }
+        case GGML_OP_RMS_NORM:
+            {
+                ggml_webgpu_rms_norm(ctx, src0, node);
+                break;
+            }
         default:
             return false;
     }
@@ -947,6 +1004,14 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
     return reinterpret_cast<ggml_guid_t>((void *) guid_str);
 }
 
+// The max workgroup size is a common constant
+static std::vector<wgpu::ConstantEntry> max_wg_size_entry(webgpu_context & webgpu_ctx) {
+    std::vector<wgpu::ConstantEntry> constants(1);
+    constants[0].key   = "wg_size";
+    constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
+    return constants;
+}
+
 static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
     // we use the maximum workgroup size for the memset pipeline
     size_t max_wg_size = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
@@ -1010,24 +1075,16 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
 }
 
 static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants(1);
-    constants[0].key   = "wg_size";
-    constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
-                                constants);
+                                max_wg_size_entry(webgpu_ctx));
 }
 
 static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants(1);
-    constants[0].key   = "wg_size";
-    constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", max_wg_size_entry(webgpu_ctx));
 }
 
 static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants(1);
-    constants[0].key   = "wg_size";
-    constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
+    std::vector<wgpu::ConstantEntry> constants = max_wg_size_entry(webgpu_ctx);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32",
                                 constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16",
@@ -1039,9 +1096,7 @@ static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
 }
 
 static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants(1);
-    constants[0].key   = "wg_size";
-    constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
+    std::vector<wgpu::ConstantEntry> constants = max_wg_size_entry(webgpu_ctx);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32], wgsl_mul_f32, "mul_f32",
                                 constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16], wgsl_mul_f16, "mul_f16",
@@ -1052,6 +1107,14 @@ static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
                                 "mul_in_place_f16", constants);
 }
 
+static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
+    std::vector<wgpu::ConstantEntry> constants = max_wg_size_entry(webgpu_ctx);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline, wgsl_rms_norm, "rms_norm",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_ip_pipeline, wgsl_rms_norm_in_place,
+                                "rms_norm_in_place", constants);
+}
+
 static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
     GGML_UNUSED(params);
 
@@ -1158,6 +1221,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
                     break;
                 }
             }
+        case GGML_OP_RMS_NORM:
+            supports_op = op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
+            break;
         default:
             break;
     }
@@ -1282,6 +1347,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
     ggml_webgpu_init_cpy_pipeline(ctx);
    ggml_webgpu_init_add_pipeline(ctx);
     ggml_webgpu_init_mul_pipeline(ctx);
+    ggml_webgpu_init_rms_norm_pipeline(ctx);
 
 #ifdef GGML_WEBGPU_DEBUG
     // Initialize debug buffers
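A note on the eps plumbing in ggml_webgpu_rms_norm above: the push-parameter vector is uniformly u32, so the host copies the float epsilon's raw bit pattern into a u32 slot with memcpy, and the shader recovers it with bitcast<f32>(params.eps). A minimal standalone C++ sketch of that round trip (illustrative only, not part of the patch):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    int main() {
        float eps = 1e-6f;  // the value ggml stores in dst->op_params
        // Host side: reinterpret the float's bytes as a u32 (what the memcpy above does).
        uint32_t bits;
        memcpy(&bits, &eps, sizeof(float));
        // Shader side: bitcast<f32>(params.eps) restores the identical value.
        float round_trip;
        memcpy(&round_trip, &bits, sizeof(float));
        printf("%g -> 0x%08x -> %g\n", eps, bits, round_trip);  // bits preserved exactly
        return 0;
    }
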
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl
new file mode 100644
index 0000000000..f919a51336
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl
@@ -0,0 +1,57 @@
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+struct Params {
+    offset_src: u32, // in elements
+    offset_dst: u32,
+
+    // Strides (in elements)
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    // Shape of src/dst
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+    ne3: u32,
+
+    eps: u32
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x >= params.ne1 * params.ne2 * params.ne3) {
+        return;
+    }
+
+    // one thread per row
+    var i = gid.x;
+    let i3 = i / (params.ne2 * params.ne1);
+    i = i % (params.ne2 * params.ne1);
+    let i2 = i / params.ne1;
+    let i1 = i % params.ne1;
+    let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
+    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+
+    var sum = 0.0f;
+    for (var j: u32 = 0; j < params.ne0; j++) {
+        sum += src[i_src_row + j] * src[i_src_row + j];
+    }
+    let eps = bitcast<f32>(params.eps);
+    let scale = 1.0/sqrt(sum/f32(params.ne0) + eps);
+    for (var j: u32 = 0; j < params.ne0; j++) {
+        dst[i_dst_row + j] = scale * src[i_src_row + j];
+    }
+}
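Both RMS-norm shaders run one invocation per row and recover the (i1, i2, i3) row coordinates from the flat id gid.x by division and modulo. A small C++ check of that decomposition (the dimensions are made-up example values, not from the patch; ne0 would be the row length):

    #include <cassert>
    #include <cstdint>

    int main() {
        const uint32_t ne1 = 4, ne2 = 3, ne3 = 2;   // example dims for ne1*ne2*ne3 rows
        for (uint32_t gid = 0; gid < ne1 * ne2 * ne3; gid++) {
            uint32_t i  = gid;
            uint32_t i3 = i / (ne2 * ne1);          // slowest-varying index
            i           = i % (ne2 * ne1);
            uint32_t i2 = i / ne1;
            uint32_t i1 = i % ne1;                  // fastest-varying index
            // recomposing the coordinates must give back the flat row id
            assert((i3 * ne2 + i2) * ne1 + i1 == gid);
        }
        return 0;
    }
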
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl
new file mode 100644
index 0000000000..ae84f556d6
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl
@@ -0,0 +1,48 @@
+@group(0) @binding(0)
+var<storage, read_write> a: array<f32>;
+
+struct Params {
+    offset: u32, // in elements
+
+    // Strides (in elements)
+    stride1: u32,
+    stride2: u32,
+    stride3: u32,
+
+    // Shape
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+    ne3: u32,
+
+    eps: u32
+};
+
+@group(0) @binding(1)
+var<uniform> params: Params;
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x >= params.ne1 * params.ne2 * params.ne3) {
+        return;
+    }
+
+    // one thread per row
+    var i = gid.x;
+    let i3 = i / (params.ne2 * params.ne1);
+    i = i % (params.ne2 * params.ne1);
+    let i2 = i / params.ne1;
+    let i1 = i % params.ne1;
+    let i_row = params.offset + i3 * params.stride3 + i2 * params.stride2 + i1 * params.stride1;
+
+    var sum = 0.0f;
+    for (var j: u32 = 0; j < params.ne0; j++) {
+        sum += a[i_row + j] * a[i_row + j];
+    }
+    let eps = bitcast<f32>(params.eps);
+    let scale = 1.0/sqrt(sum/f32(params.ne0) + eps);
+    for (var j: u32 = 0; j < params.ne0; j++) {
+        a[i_row + j] = scale * a[i_row + j];
+    }
+}
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl
index 4bd6f94a23..3567713dc2 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl
@@ -52,7 +52,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
     }
     var i = gid.x;
     let i_src3 = i / (params.ne2 * params.n_rows);
-    let i_dst3 = i / (params.ne2 * 3);
     i = i % (params.ne2 * params.n_rows);
     let i_src2 = i / params.n_rows;
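For reference, the computation both shaders implement is the standard RMS norm: each row of ne0 elements is scaled by 1/sqrt(mean(x^2) + eps). A CPU sketch assuming contiguous rows (the helper name rms_norm_row is illustrative, not a ggml function):

    #include <cmath>
    #include <cstdio>

    // RMS-normalize one contiguous row of ne0 floats.
    static void rms_norm_row(const float * src, float * dst, int ne0, float eps) {
        float sum = 0.0f;
        for (int j = 0; j < ne0; j++) {
            sum += src[j] * src[j];  // sum of squares over the row
        }
        const float scale = 1.0f / sqrtf(sum / (float) ne0 + eps);
        for (int j = 0; j < ne0; j++) {
            dst[j] = scale * src[j];  // dst may alias src, as in the in-place variant
        }
    }

    int main() {
        float row[4] = { 1.0f, 2.0f, 3.0f, 4.0f };
        rms_norm_row(row, row, 4, 1e-6f);  // in-place, like rms_norm_in_place.wgsl
        printf("%f %f %f %f\n", row[0], row[1], row[2], row[3]);
        return 0;
    }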