ggml webgpu: actually add softmax, fix rms_norm offset (#16400)
* Implement soft_max
* Fix soft_max data race
* Temporary fix, wait on each submit
This commit is contained in:
parent
86df2c9ae4
commit
35266573b9
|
|
@ -424,6 +424,7 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
|
||||||
ctx->staged_param_bufs.push_back(params_bufs);
|
ctx->staged_param_bufs.push_back(params_bufs);
|
||||||
if (ctx->staged_command_bufs.size() == WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
|
if (ctx->staged_command_bufs.size() == WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
|
||||||
ggml_backend_webgpu_submit_queue(ctx);
|
ggml_backend_webgpu_submit_queue(ctx);
|
||||||
|
ggml_backend_webgpu_wait_on_submission(ctx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1060,6 +1061,9 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
ggml_webgpu_scale(ctx, src0, node);
|
ggml_webgpu_scale(ctx, src0, node);
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_SOFT_MAX:
|
||||||
|
ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
@ -1806,6 +1810,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
supports_op = op->type == GGML_TYPE_F32;
|
supports_op = op->type == GGML_TYPE_F32;
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_SOFT_MAX:
|
||||||
|
supports_op = op->type == GGML_TYPE_F32;
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
@ -1949,6 +1956,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||||
ggml_webgpu_init_rope_pipeline(ctx);
|
ggml_webgpu_init_rope_pipeline(ctx);
|
||||||
ggml_webgpu_init_glu_pipeline(ctx);
|
ggml_webgpu_init_glu_pipeline(ctx);
|
||||||
ggml_webgpu_init_scale_pipeline(ctx);
|
ggml_webgpu_init_scale_pipeline(ctx);
|
||||||
|
ggml_webgpu_init_soft_max_pipeline(ctx);
|
||||||
|
|
||||||
#ifdef GGML_WEBGPU_DEBUG
|
#ifdef GGML_WEBGPU_DEBUG
|
||||||
// Initialize debug buffers
|
// Initialize debug buffers
|
||||||
|
|
|
||||||
|
|
@ -84,7 +84,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
|
||||||
let i2 = i / params.ne1;
|
let i2 = i / params.ne1;
|
||||||
let i1 = i % params.ne1;
|
let i1 = i % params.ne1;
|
||||||
let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
|
let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
|
||||||
let i_dst_row = params.offset_src + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
|
let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
|
||||||
|
|
||||||
let elems = (params.ne0 + wg_size - 1) / wg_size;
|
let elems = (params.ne0 + wg_size - 1) / wg_size;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -300,6 +300,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
|
||||||
workgroupBarrier();
|
workgroupBarrier();
|
||||||
}
|
}
|
||||||
let row_max = scratch[0];
|
let row_max = scratch[0];
|
||||||
|
workgroupBarrier();
|
||||||
|
|
||||||
var sum = 0.0f;
|
var sum = 0.0f;
|
||||||
col = lid.x;
|
col = lid.x;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue