metal : add env var to trigger graph capture (#20398)
This commit is contained in:
parent
ecac98ee53
commit
c363256839
|
|
@@ -47,7 +47,7 @@ struct ggml_metal {
|
|||
uint64_t fuse_cnt[GGML_OP_COUNT];
|
||||
|
||||
// capture state
|
||||
bool capture_next_compute;
|
||||
int capture_compute;
|
||||
bool capture_started;
|
||||
|
||||
id<MTLCaptureScope> capture_scope;
|
||||
|
|
@@ -158,10 +158,17 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
|
|||
GGML_LOG_INFO("%s: use concurrency = %s\n", __func__, res->use_concurrency ? "true" : "false");
|
||||
GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? "true" : "false");
|
||||
|
||||
res->capture_next_compute = false;
|
||||
res->capture_compute = 0;
|
||||
res->capture_started = false;
|
||||
res->capture_scope = nil;
|
||||
|
||||
{
|
||||
const char * val = getenv("GGML_METAL_CAPTURE_COMPUTE");
|
||||
if (val) {
|
||||
res->capture_compute = atoi(val);
|
||||
}
|
||||
}
|
||||
|
||||
res->has_error = false;
|
||||
|
||||
res->gf = nil;
|
||||
|
|
@@ -458,9 +465,13 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
|
|||
|
||||
ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb;
|
||||
|
||||
const bool use_capture = ctx->capture_next_compute;
|
||||
if (ctx->capture_compute > 0) {
|
||||
ctx->capture_compute--;
|
||||
}
|
||||
|
||||
const bool use_capture = ctx->capture_compute == 0;
|
||||
if (use_capture) {
|
||||
ctx->capture_next_compute = false;
|
||||
ctx->capture_compute = -1;
|
||||
|
||||
// make sure all previous computations have finished before starting the capture
|
||||
if (ctx->cmd_buf_last) {
|
||||
|
|
@@ -469,6 +480,10 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
|
|||
}
|
||||
|
||||
if (!ctx->capture_started) {
|
||||
NSString * path = [NSString stringWithFormat:@"/tmp/perf-metal-%d.gputrace", getpid()];
|
||||
|
||||
GGML_LOG_WARN("%s: capturing graph in %s\n", __func__, [path UTF8String]);
|
||||
|
||||
// create capture scope
|
||||
id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
|
||||
ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:device];
|
||||
|
|
@@ -476,7 +491,7 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
|
|||
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
|
||||
descriptor.captureObject = ctx->capture_scope;
|
||||
descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
|
||||
descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];
|
||||
descriptor.outputURL = [NSURL fileURLWithPath:path];
|
||||
|
||||
NSError * error = nil;
|
||||
if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
|
||||
|
|
@@ -683,7 +698,7 @@ void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
|
|||
idx_end,
|
||||
ctx->use_fusion,
|
||||
ctx->use_concurrency,
|
||||
ctx->capture_next_compute,
|
||||
ctx->capture_compute,
|
||||
ctx->debug_graph,
|
||||
ctx->debug_fusion);
|
||||
|
||||
|
|
@@ -718,5 +733,5 @@ bool ggml_metal_supports_family(ggml_metal_t ctx, int family) {
|
|||
}
|
||||
|
||||
void ggml_metal_capture_next_compute(ggml_metal_t ctx) {
|
||||
ctx->capture_next_compute = true;
|
||||
ctx->capture_compute = 1;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue