Move rms_norm to split row approach
parent 483615da76
commit 27b893a6f8
@@ -28,6 +28,7 @@
 /* Constants */

 #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
+#define WEBGPU_WAIT_ANY_BATCH_SIZE 64
 #define WEBGPU_MUL_MAT_WG_SIZE 64
 #define WEBGPU_NUM_PARAM_BUFS 100
 #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128  // enough for 32 parameters
@@ -35,6 +36,9 @@
 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
 #define WEBGPU_STORAGE_BUF_BINDING_MULT 4  // a storage buffer binding size must be a multiple of 4

+// For operations which process a row in parallel, this seems like a reasonable default
+#define WEBGPU_ROW_SPLIT_WG_SIZE 64
+
 /* End Constants */

 // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
@@ -257,8 +261,12 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
                 }),
             UINT64_MAX);
     } else {
-        // existing callbacks, wait on them
-        ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
+        // WebGPU implementations may limit the number of futures that can be waited on at once,
+        // so wait in batches (64 is what Dawn supports).
+        for (size_t i = 0; i < ctx->callback_futures.size(); i += WEBGPU_WAIT_ANY_BATCH_SIZE) {
+            size_t end = std::min(i + WEBGPU_WAIT_ANY_BATCH_SIZE, ctx->callback_futures.size());
+            ctx->instance.WaitAny(end - i, ctx->callback_futures.data() + i, UINT64_MAX);
+        }
         ctx->callback_futures.clear();
     }
 }
@@ -727,9 +735,7 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t
                           .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
     }

-    size_t   max_wg_size = ctx->max_wg_size_x;
-    uint32_t wg_x        = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
-    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, wg_x,
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, ggml_nrows(src),
                                           ggml_op_name(dst->op));
 }

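A rough sketch of the dispatch change in this hunk (the helper names below are illustrative, not the backend's API): before, each thread handled a whole row, so the grid was the row count divided by the maximum workgroup size; with the split-row approach there is one workgroup per row and its threads share the work on that row.

#include <stdint.h>

// before: one thread per row, rows packed into workgroups of max_wg_size threads
static uint32_t wg_count_thread_per_row(uint32_t nrows, uint32_t max_wg_size) {
    return (nrows + max_wg_size - 1) / max_wg_size;
}

// after: one workgroup per row; its WEBGPU_ROW_SPLIT_WG_SIZE (64) threads cooperate
// on that row, and nrows corresponds to ne1 * ne2 * ne3, i.e. ggml_nrows(src)
static uint32_t wg_count_row_split(uint32_t nrows) {
    return nrows;
}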
@@ -1311,11 +1317,11 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
     return reinterpret_cast<ggml_guid_t>((void *) guid_str);
 }

-// The max workgroup size is a common constant
-static std::vector<wgpu::ConstantEntry> ggml_webgpu_max_wg_size_entry(webgpu_context & webgpu_ctx) {
+// Workgroup size is a common constant
+static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
     std::vector<wgpu::ConstantEntry> constants(1);
     constants[0].key   = "wg_size";
-    constants[0].value = webgpu_ctx->max_wg_size_x;
+    constants[0].value = wg_size;
     return constants;
 }

@@ -1383,11 +1389,11 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {

 static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
-                                ggml_webgpu_max_wg_size_entry(webgpu_ctx));
+                                ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
 }

 static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32_vec,
                                 "get_rows_f32_vec", constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_f32_no_vec_pipeline, wgsl_get_rows_f32,
@@ -1437,7 +1443,7 @@ static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
 }

 static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
                                 wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F16],
@@ -1449,7 +1455,7 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
 }

 static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32",
                                 constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16",
@@ -1461,7 +1467,7 @@ static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
 }

 static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32",
                                 constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16",
@@ -1473,7 +1479,7 @@ static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
 }

 static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32",
                                 constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16",
@@ -1485,7 +1491,7 @@ static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
 }

 static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32",
                                 constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16",
@@ -1497,7 +1503,7 @@ static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
 }

 static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[0], wgsl_rms_norm, "rms_norm",
                                 constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[1], wgsl_rms_norm_inplace,
@@ -1505,7 +1511,7 @@ static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
 }

 static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][0], wgsl_rope_f32,
                                 "rope_f32", constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][1],
@@ -1525,7 +1531,7 @@ static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
 }

 static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     // reglu
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0],
                                 wgsl_reglu_f32, "reglu_f32", constants);
@@ -1579,7 +1585,7 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
 }

 static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32",
                                 constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[1], wgsl_scale_f32_inplace,
@@ -1587,9 +1593,7 @@ static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
 }

 static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants(1);
-    constants[0].key   = "wg_size";
-    constants[0].value = 64;
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][0], wgsl_soft_max_f32,
                                 "soft_max_f32", constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][1], wgsl_soft_max_f32_inplace,

@@ -71,14 +71,14 @@ var<storage, read_write> src: array<f32>;
 DECLS

 override wg_size: u32;
+var<workgroup> scratch: array<f32, wg_size>;

 @compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x >= params.ne1 * params.ne2 * params.ne3) {
-        return;
-    }
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>) {

     // one thread per row
-    var i = gid.x;
+    var i = wid.x;
     let i3 = i / (params.ne2 * params.ne1);
     i = i % (params.ne2 * params.ne1);
     let i2 = i / params.ne1;
@@ -86,13 +86,38 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
     let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
     let i_dst_row = params.offset_src + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;

+    let elems = (params.ne0 + wg_size - 1) / wg_size;
+
     var sum = 0.0f;
-    for (var j: u32 = 0; j < params.ne0; j++) {
-        sum += src[i_src_row + j] * src[i_src_row + j];
+    var col = lid.x;
+    for (var j: u32 = 0; j < elems; j++) {
+        if (col >= params.ne0) {
+            break;
+        }
+        sum += pow(src[i_src_row + col], 2.0);
+        col += wg_size;
     }

+    scratch[lid.x] = sum;
+    workgroupBarrier();
+    var offset = wg_size / 2;
+    while (offset > 0) {
+        if (lid.x < offset) {
+            scratch[lid.x] += scratch[lid.x + offset];
+        }
+        offset = offset / 2;
+        workgroupBarrier();
+    }
+    sum = scratch[0];
+
     let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
-    for (var j: u32 = 0; j < params.ne0; j++) {
-        update(i_src_row + j, i_dst_row + j, scale);
+    col = lid.x;
+    for (var j: u32 = 0; j < elems; j++) {
+        if (col >= params.ne0) {
+            break;
+        }
+        update(i_src_row + col, i_dst_row + col, scale);
+        col += wg_size;
     }
 }
 #end(SHADER)

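As a rough CPU model of what one workgroup now computes per row in the shader above (the function name and the serial loops standing in for threads are illustrative only; it assumes wg_size is a power of two and at most 64, matching WEBGPU_ROW_SPLIT_WG_SIZE):

#include <math.h>
#include <stdint.h>

// Each "thread" t accumulates a strided partial sum of squares, the partials are
// tree-reduced (the scratch[] loop in the shader), and the resulting scale is
// applied to every column of the row.
static void rms_norm_row(float * row, uint32_t ne0, uint32_t wg_size, float eps) {
    float scratch[64] = { 0 };
    for (uint32_t t = 0; t < wg_size; t++) {
        for (uint32_t col = t; col < ne0; col += wg_size) {
            scratch[t] += row[col] * row[col];
        }
    }
    for (uint32_t offset = wg_size / 2; offset > 0; offset /= 2) {
        for (uint32_t t = 0; t < offset; t++) {
            scratch[t] += scratch[t + offset];
        }
    }
    const float scale = 1.0f / sqrtf(scratch[0] / (float) ne0 + eps);
    for (uint32_t col = 0; col < ne0; col++) {
        row[col] *= scale;
    }
}
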
@@ -276,15 +276,17 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
     var cache: array<f32, CACHE_SIZE>;

     var max_val = lower_max_bound(i2);
+    var col = lid.x;
     for (var j: u32 = 0; j < elems; j++) {
-        let col = j * wg_size + lid.x;
-        if (col < params.ne0) {
-            let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col);
-            max_val = max(max_val, val);
-            if (col < CACHE_SIZE) {
-                cache[col] = val;
-            }
+        if (col >= params.ne0) {
+            break;
         }
+        let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col);
+        max_val = max(max_val, val);
+        if (col < CACHE_SIZE) {
+            cache[col] = val;
+        }
+        col += wg_size;
     }

     scratch[lid.x] = max_val;
@@ -300,19 +302,21 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
     let row_max = scratch[0];

     var sum = 0.0f;
+    col = lid.x;
     for (var j: u32 = 0; j < elems; j++) {
-        let col = j * wg_size + lid.x;
-        if (col < params.ne0) {
-            let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col),
-                             cache[col], col < CACHE_SIZE);
-            let ex = exp(val - row_max);
-            sum += ex;
-            if (col < CACHE_SIZE) {
-                cache[col] = ex;
-            } else {
-                update(i_dst_row + col, ex);
-            }
+        if (col >= params.ne0) {
+            break;
         }
+        let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col),
+                         cache[col], col < CACHE_SIZE);
+        let ex = exp(val - row_max);
+        sum += ex;
+        if (col < CACHE_SIZE) {
+            cache[col] = ex;
+        } else {
+            update(i_dst_row + col, ex);
+        }
+        col += wg_size;
     }

     scratch[lid.x] = sum;
@@ -328,11 +332,13 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
     let row_sum = add_sinks(scratch[0], i2, row_max);

     let sum_recip = 1.0 / row_sum;
+    col = lid.x;
     for (var j: u32 = 0; j < elems; j++) {
-        let col = j * wg_size + lid.x;
-        if (col < params.ne0) {
-            update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip);
+        if (col >= params.ne0) {
+            break;
         }
+        update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip);
+        col += wg_size;
     }
 }
 #end(SHADER)
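The soft_max hunks above keep the shader's three-pass structure and only change how each thread walks its columns (strided by wg_size with an early break, as in rms_norm). As a rough serial model of that per-row computation (the function name is illustrative, and it ignores the scale, mask/slope, sinks, and CACHE_SIZE reuse that the real shader applies):

#include <math.h>
#include <stdint.h>

// pass 1: row maximum; pass 2: exponentials relative to the max and their sum;
// pass 3: normalize by the reciprocal of the sum.
static void soft_max_row(float * row, uint32_t ne0) {
    float max_val = -INFINITY;
    for (uint32_t col = 0; col < ne0; col++) {
        max_val = fmaxf(max_val, row[col]);
    }
    float sum = 0.0f;
    for (uint32_t col = 0; col < ne0; col++) {
        row[col] = expf(row[col] - max_val);
        sum += row[col];
    }
    const float sum_recip = 1.0f / sum;
    for (uint32_t col = 0; col < ne0; col++) {
        row[col] *= sum_recip;
    }
}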