Remove pipeline cache mutexes (#19195)

* Remove mutex for pipeline caches, since they are now per-thread.

* Add comment

* Run clang-format

* Cleanup

* Run CI again

* Run CI once more

* Run clang-format
This commit is contained in:
Nikhil Jain 2026-02-01 18:47:29 -08:00 committed by GitHub
parent 3bc8d2cf23
commit 2dc3ce2166
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 93 additions and 120 deletions

View File

@ -146,8 +146,13 @@ struct webgpu_submission_futures {
struct webgpu_buf_pool { struct webgpu_buf_pool {
std::vector<webgpu_pool_bufs> free; std::vector<webgpu_pool_bufs> free;
std::mutex mutex; // The pool must be synchronized because
// 1. The memset pool is shared globally by every ggml buffer,
// since allocating a pool per ggml buffer would consume too much memory.
// 2. For the per-thread buffer pools in webgpu_context,
// buffers are allocated and freed in Dawn callbacks,
// which can run on a different thread than the calling thread.
std::mutex mutex;
std::condition_variable cv; std::condition_variable cv;
void init(wgpu::Device device, void init(wgpu::Device device,
@ -266,7 +271,7 @@ struct webgpu_command {
#endif #endif
}; };
struct webgpu_capabilities_base { struct webgpu_capabilities {
wgpu::Limits limits; wgpu::Limits limits;
bool supports_subgroup_matrix = false; bool supports_subgroup_matrix = false;
@ -286,11 +291,11 @@ struct webgpu_global_context_struct {
wgpu::Device device; wgpu::Device device;
wgpu::Queue queue; wgpu::Queue queue;
webgpu_capabilities_base capabilities; webgpu_capabilities capabilities;
// Shared buffer to move data from device to host // Shared buffer to move data from device to host
wgpu::Buffer get_tensor_staging_buf; wgpu::Buffer get_tensor_staging_buf;
// Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches. // Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches.
std::recursive_mutex mutex; std::recursive_mutex mutex;
webgpu_buf_pool memset_buf_pool; webgpu_buf_pool memset_buf_pool;
std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
@ -361,7 +366,6 @@ struct webgpu_context_struct {
std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash> pad_pipelines; std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash> pad_pipelines;
size_t memset_bytes_per_thread; size_t memset_bytes_per_thread;
}; };
typedef std::shared_ptr<webgpu_context_struct> webgpu_context; typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
@ -383,9 +387,8 @@ struct ggml_backend_webgpu_device_context {
// Per-thread data required to actually run WebGPU operations in a backend instance // Per-thread data required to actually run WebGPU operations in a backend instance
struct ggml_backend_webgpu_context { struct ggml_backend_webgpu_context {
webgpu_context webgpu_ctx; webgpu_context webgpu_ctx;
std::once_flag init_once; std::string name;
std::string name;
}; };
// Per-thread data related to buffers // Per-thread data related to buffers
@ -861,20 +864,15 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g
}; };
webgpu_pipeline pipeline; webgpu_pipeline pipeline;
{ auto it = ctx->pad_pipelines.find(pipeline_key);
// TODO: remove guard once pipeline caches are per-thread if (it != ctx->pad_pipelines.end()) {
std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex); pipeline = it->second;
auto it = ctx->pad_pipelines.find(pipeline_key); } else {
if (it != ctx->pad_pipelines.end()) { ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx);
pipeline = it->second; pipeline =
} else { ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
ggml_webgpu_processed_shader processed = pipeline.context = processed.decisions;
ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx); ctx->pad_pipelines.emplace(pipeline_key, pipeline);
pipeline =
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
pipeline.context = processed.decisions;
ctx->pad_pipelines.emplace(pipeline_key, pipeline);
}
} }
ggml_webgpu_generic_shader_decisions decisions = ggml_webgpu_generic_shader_decisions decisions =
@ -944,20 +942,16 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
}; };
webgpu_pipeline pipeline; webgpu_pipeline pipeline;
// TODO: remove guard once pipeline caches are per-thread auto it = ctx->set_rows_pipelines.find(key);
{ if (it != ctx->set_rows_pipelines.end()) {
std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex); pipeline = it->second;
auto it = ctx->set_rows_pipelines.find(key); } else {
if (it != ctx->set_rows_pipelines.end()) { ggml_webgpu_processed_shader processed =
pipeline = it->second; ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx);
} else { pipeline =
ggml_webgpu_processed_shader processed = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx); pipeline.context = processed.decisions;
pipeline = ctx->set_rows_pipelines.emplace(key, pipeline);
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
pipeline.context = processed.decisions;
ctx->set_rows_pipelines.emplace(key, pipeline);
}
} }
ggml_webgpu_generic_shader_decisions decisions = ggml_webgpu_generic_shader_decisions decisions =
@ -1261,29 +1255,25 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
}; };
webgpu_pipeline pipeline; webgpu_pipeline pipeline;
// TODO: remove guard once pipeline caches are per-thread auto it = ctx->flash_attn_pipelines.find(key);
{ if (it != ctx->flash_attn_pipelines.end()) {
std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex); pipeline = it->second;
auto it = ctx->flash_attn_pipelines.find(key); } else {
if (it != ctx->flash_attn_pipelines.end()) { ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
pipeline = it->second; .key = key,
} else { .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
.key = key, .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size
.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, };
.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size
};
ggml_webgpu_processed_shader processed = ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
pipeline = pipeline =
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
pipeline.context = processed.decisions; pipeline.context = processed.decisions;
ctx->flash_attn_pipelines.emplace(key, pipeline); ctx->flash_attn_pipelines.emplace(key, pipeline);
}
} }
ggml_webgpu_flash_attn_shader_decisions decisions = ggml_webgpu_flash_attn_shader_decisions decisions =
@ -1308,20 +1298,16 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s
}; };
webgpu_pipeline pipeline; webgpu_pipeline pipeline;
{ auto it = ctx->unary_pipelines.find(pipeline_key);
// TODO: remove guard once pipeline caches are per-thread if (it != ctx->unary_pipelines.end()) {
std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex); pipeline = it->second;
auto it = ctx->unary_pipelines.find(pipeline_key); } else {
if (it != ctx->unary_pipelines.end()) { ggml_webgpu_processed_shader processed =
pipeline = it->second; ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx);
} else { pipeline =
ggml_webgpu_processed_shader processed = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx); pipeline.context = processed.decisions;
pipeline = ctx->unary_pipelines.emplace(pipeline_key, pipeline);
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
pipeline.context = processed.decisions;
ctx->unary_pipelines.emplace(pipeline_key, pipeline);
}
} }
ggml_webgpu_generic_shader_decisions decisions = ggml_webgpu_generic_shader_decisions decisions =
@ -1743,19 +1729,15 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src
}; };
webgpu_pipeline pipeline; webgpu_pipeline pipeline;
{ auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4);
// TODO: remove guard once pipeline caches are per-thread if (it != ctx->argmax_pipelines.end()) {
std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex); pipeline = it->second;
auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4); } else {
if (it != ctx->argmax_pipelines.end()) { ggml_webgpu_processed_shader processed =
pipeline = it->second; ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax");
} else { pipeline =
ggml_webgpu_processed_shader processed = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax"); ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline);
pipeline =
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline);
}
} }
uint32_t wg_x = ggml_nelements(dst); uint32_t wg_x = ggml_nelements(dst);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
@ -1772,9 +1754,8 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr
.order = order .order = order
}; };
std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex); webgpu_pipeline argsort_pipeline;
webgpu_pipeline argsort_pipeline; auto it = ctx->argsort_pipelines.find(order);
auto it = ctx->argsort_pipelines.find(order);
if (it != ctx->argsort_pipelines.end()) { if (it != ctx->argsort_pipelines.end()) {
argsort_pipeline = it->second; argsort_pipeline = it->second;
} else { } else {
@ -1963,19 +1944,15 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
}; };
webgpu_pipeline pipeline; webgpu_pipeline pipeline;
// TODO: remove guard once pipeline caches are per-thread auto it = ctx->cumsum_pipelines.find(1);
{ if (it != ctx->cumsum_pipelines.end()) {
std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex); pipeline = it->second;
auto it = ctx->cumsum_pipelines.find(1); } else {
if (it != ctx->cumsum_pipelines.end()) { ggml_webgpu_processed_shader processed =
pipeline = it->second; ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum");
} else { pipeline =
ggml_webgpu_processed_shader processed = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum"); ctx->cumsum_pipelines.emplace(1, pipeline);
pipeline =
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
ctx->cumsum_pipelines.emplace(1, pipeline);
}
} }
uint32_t wg_x = ggml_nrows(dst); uint32_t wg_x = ggml_nrows(dst);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
@ -2009,19 +1986,15 @@ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * s
}; };
webgpu_pipeline pipeline; webgpu_pipeline pipeline;
{ auto it = ctx->sum_rows_pipelines.find(1);
// TODO: remove guard once pipeline caches are per-thread if (it != ctx->sum_rows_pipelines.end()) {
std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex); pipeline = it->second;
auto it = ctx->sum_rows_pipelines.find(1); } else {
if (it != ctx->sum_rows_pipelines.end()) { ggml_webgpu_processed_shader processed =
pipeline = it->second; ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows");
} else { pipeline =
ggml_webgpu_processed_shader processed = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows"); ctx->sum_rows_pipelines.emplace(1, pipeline);
pipeline =
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
ctx->sum_rows_pipelines.emplace(1, pipeline);
}
} }
uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst); uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
@ -3016,10 +2989,10 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
#ifdef GGML_WEBGPU_GPU_PROFILE #ifdef GGML_WEBGPU_GPU_PROFILE
// Initialize buffer pool for timestamp queries, used for profiling // Initialize buffer pool for timestamp queries, used for profiling
ctx->webgpu_global_ctx->timestamp_query_buf_pool.init(ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, ctx->webgpu_global_ctx->timestamp_query_buf_pool.init(
WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst); wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
#endif #endif
GGML_LOG_INFO( GGML_LOG_INFO(