Implement rms_norm

This commit is contained in:
Reese Levine 2025-09-09 17:10:23 -07:00
parent efc0cb0c28
commit 7fbe84cd5f
4 changed files with 185 additions and 15 deletions

View File

@ -129,6 +129,8 @@ struct webgpu_context_struct {
wgpu::ComputePipeline add_ip_pipeline[2];
wgpu::ComputePipeline mul_pipeline[2];
wgpu::ComputePipeline mul_ip_pipeline[2];
wgpu::ComputePipeline rms_norm_pipeline;
wgpu::ComputePipeline rms_norm_ip_pipeline;
size_t memset_bytes_per_thread;
@ -640,6 +642,56 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx,
ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
}
static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
bool in_place = ggml_webgpu_tensor_equal(src, dst);
uint32_t eps;
memcpy(&eps, dst->op_params, sizeof(float));
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
};
if (!in_place) {
params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)));
}
params.push_back((uint32_t) (src->nb[1] / ggml_type_size(src->type)));
params.push_back((uint32_t) (src->nb[2] / ggml_type_size(src->type)));
params.push_back((uint32_t) (src->nb[3] / ggml_type_size(src->type)));
if (!in_place) {
params.push_back((uint32_t) (dst->nb[1] / ggml_type_size(dst->type)));
params.push_back((uint32_t) (dst->nb[2] / ggml_type_size(dst->type)));
params.push_back((uint32_t) (dst->nb[3] / ggml_type_size(dst->type)));
}
params.push_back((uint32_t) src->ne[0]);
params.push_back((uint32_t) src->ne[1]);
params.push_back((uint32_t) src->ne[2]);
params.push_back((uint32_t) src->ne[3]);
params.push_back(eps); // epsilon, will be bitcast to float in shader
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src),
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
.size = ggml_webgpu_tensor_binding_size(ctx, src) }
};
if (!in_place) {
entries.push_back({ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
wgpu::ComputePipeline pipeline;
if (in_place) {
pipeline = ctx->rms_norm_ip_pipeline;
} else {
pipeline = ctx->rms_norm_pipeline;
}
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
}
// Returns true if node has enqueued work into the queue, false otherwise
static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
if (ggml_is_empty(node)) {
@ -691,6 +743,11 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
}
break;
}
case GGML_OP_RMS_NORM:
{
ggml_webgpu_rms_norm(ctx, src0, node);
break;
}
default:
return false;
}
@ -947,6 +1004,14 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
return reinterpret_cast<ggml_guid_t>((void *) guid_str);
}
// The max workgroup size is a common constant
static std::vector<wgpu::ConstantEntry> max_wg_size_entry(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
return constants;
}
static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
// we use the maximum workgroup size for the memset pipeline
size_t max_wg_size = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
@ -1010,24 +1075,16 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
constants);
max_wg_size_entry(webgpu_ctx));
}
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", max_wg_size_entry(webgpu_ctx));
}
static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
std::vector<wgpu::ConstantEntry> constants = max_wg_size_entry(webgpu_ctx);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16",
@ -1039,9 +1096,7 @@ static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
std::vector<wgpu::ConstantEntry> constants = max_wg_size_entry(webgpu_ctx);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32], wgsl_mul_f32, "mul_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16], wgsl_mul_f16, "mul_f16",
@ -1052,6 +1107,14 @@ static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
"mul_in_place_f16", constants);
}
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = max_wg_size_entry(webgpu_ctx);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline, wgsl_rms_norm, "rms_norm",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_ip_pipeline, wgsl_rms_norm_in_place,
"rms_norm_in_place", constants);
}
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
GGML_UNUSED(params);
@ -1158,6 +1221,8 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
break;
}
}
case GGML_OP_RMS_NORM:
supports_op = op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
default:
break;
}
@ -1282,6 +1347,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
ggml_webgpu_init_cpy_pipeline(ctx);
ggml_webgpu_init_add_pipeline(ctx);
ggml_webgpu_init_mul_pipeline(ctx);
ggml_webgpu_init_rms_norm_pipeline(ctx);
#ifdef GGML_WEBGPU_DEBUG
// Initialize debug buffers

View File

@ -0,0 +1,57 @@
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
struct Params {
offset_src: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements)
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Shape of src/dst
ne0: u32,
ne1: u32,
ne2: u32,
ne3: u32,
eps: u32
};
@group(0) @binding(2)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne1 * params.ne2 * params.ne3) {
return;
}
// one thread per row
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1);
i = i % (params.ne2 * params.ne1);
let i2 = i / params.ne1;
let i1 = i % params.ne1;
let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
let i_dst_row = params.offset_src + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
var sum = 0.0f;
for (var j: u32 = 0; j < params.ne0; j++) {
sum += src[i_src_row + j] * src[i_src_row + j];
}
let eps = bitcast<f32>(params.eps);
let scale = 1.0/sqrt(sum/f32(params.ne0) + eps);
for (var j: u32 = 0; j < params.ne0; j++) {
dst[i_dst_row + j] = scale * src[i_src_row + j];
}
}

View File

@ -0,0 +1,48 @@
@group(0) @binding(0)
var<storage, read_write> a: array<f32>;
struct Params {
offset: u32, // in elements
// Strides (in elements)
stride1: u32,
stride2: u32,
stride3: u32,
// Shape
ne0: u32,
ne1: u32,
ne2: u32,
ne3: u32,
eps: u32
};
@group(0) @binding(1)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne1 * params.ne2 * params.ne3) {
return;
}
// one thread per row
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1);
i = i % (params.ne2 * params.ne1);
let i2 = i / params.ne1;
let i1 = i % params.ne1;
let i_row = params.offset + i3 * params.stride3 + i2 * params.stride2 + i1 * params.stride1;
var sum = 0.0f;
for (var j: u32 = 0; j < params.ne0; j++) {
sum += a[i_row + j] * a[i_row + j];
}
let eps = bitcast<f32>(params.eps);
let scale = 1.0/sqrt(sum/f32(params.ne0) + eps);
for (var j: u32 = 0; j < params.ne0; j++) {
a[i_row + j] = scale * a[i_row + j];
}
}

View File

@ -52,7 +52,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
}
var i = gid.x;
let i_src3 = i / (params.ne2 * params.n_rows);
let i_dst3 = i / (params.ne2 * 3);
i = i % (params.ne2 * params.n_rows);
let i_src2 = i / params.n_rows;