metal : use params per pipeline instance (#17739)

This commit is contained in:
Georgi Gerganov 2025-12-04 10:34:11 +02:00 committed by GitHub
parent a67ef0f47f
commit 0d1324856f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 509 additions and 619 deletions

File diff suppressed because it is too large Load Diff

View File

@ -35,20 +35,6 @@ typedef struct ggml_metal_pipeline * ggml_metal_pipeline_t;
ggml_metal_pipeline_t ggml_metal_pipeline_init(void); ggml_metal_pipeline_t ggml_metal_pipeline_init(void);
void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline); void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline);
void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg);
int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline);
void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0);
int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline);
void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1);
int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline);
void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem);
size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline);
int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline);
// a collection of pipelines // a collection of pipelines
typedef struct ggml_metal_pipelines * ggml_metal_pipelines_t; typedef struct ggml_metal_pipelines * ggml_metal_pipelines_t;
@ -58,6 +44,19 @@ void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls);
void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline); void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline);
ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name); ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name);
struct ggml_metal_pipeline_with_params {
ggml_metal_pipeline_t pipeline;
int nsg;
int nr0;
int nr1;
size_t smem;
};
int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline);
// //
// MTLCommandBuffer wrapper // MTLCommandBuffer wrapper
// //
@ -76,7 +75,7 @@ void ggml_metal_encoder_free(ggml_metal_encoder_t encoder);
void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name); void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name);
void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder); void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder);
void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline); void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline);
void ggml_metal_encoder_set_bytes (ggml_metal_encoder_t encoder, void * data, size_t size, int idx); void ggml_metal_encoder_set_bytes (ggml_metal_encoder_t encoder, void * data, size_t size, int idx);
void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx); void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx);
@ -100,66 +99,66 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
void ggml_metal_library_free(ggml_metal_library_t lib); void ggml_metal_library_free(ggml_metal_library_t lib);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name);
ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv); struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib, ggml_metal_library_t lib,
const struct ggml_tensor * op, const struct ggml_tensor * op,
bool has_mask, bool has_mask,
int32_t ncpsg); int32_t ncpsg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk(
ggml_metal_library_t lib, ggml_metal_library_t lib,
const struct ggml_tensor * op, const struct ggml_tensor * op,
int32_t nqptg, int32_t nqptg,
int32_t ncpsg); int32_t ncpsg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext(
ggml_metal_library_t lib, ggml_metal_library_t lib,
const struct ggml_tensor * op, const struct ggml_tensor * op,
bool has_mask, bool has_mask,
@ -169,7 +168,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
bool has_kvpad, bool has_kvpad,
int32_t nsg); int32_t nsg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec(
ggml_metal_library_t lib, ggml_metal_library_t lib,
const struct ggml_tensor * op, const struct ggml_tensor * op,
bool has_mask, bool has_mask,
@ -180,7 +179,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
int32_t nsg, int32_t nsg,
int32_t nwg); int32_t nwg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
ggml_metal_library_t lib, ggml_metal_library_t lib,
const struct ggml_tensor * op, const struct ggml_tensor * op,
int32_t dv, int32_t dv,

View File

@ -75,14 +75,6 @@ void ggml_metal_cv_set_bool(ggml_metal_cv_t cv, bool value, int32_t idx) {
struct ggml_metal_pipeline { struct ggml_metal_pipeline {
id<MTLComputePipelineState> obj; id<MTLComputePipelineState> obj;
// suggested dispatch sizes
int nsg;
int nr0;
int nr1;
size_t smem;
}; };
ggml_metal_pipeline_t ggml_metal_pipeline_init(void) { ggml_metal_pipeline_t ggml_metal_pipeline_init(void) {
@ -90,10 +82,6 @@ ggml_metal_pipeline_t ggml_metal_pipeline_init(void) {
*res = (struct ggml_metal_pipeline) { *res = (struct ggml_metal_pipeline) {
/*.obj =*/ nil, /*.obj =*/ nil,
/*.nsg =*/ 0,
/*.nr0 =*/ 0,
/*.nr1 =*/ 0,
/*.smem =*/ 0,
}; };
return res; return res;
@ -105,40 +93,8 @@ void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline) {
free(pipeline); free(pipeline);
} }
void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg) { int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline) {
pipeline->nsg = nsg; return pipeline.pipeline->obj.maxTotalThreadsPerThreadgroup;
}
int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline) {
return pipeline->nsg;
}
void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0) {
pipeline->nr0 = nr0;
}
int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline) {
return pipeline->nr0;
}
void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1) {
pipeline->nr1 = nr1;
}
int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline) {
return pipeline->nr1;
}
void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem) {
pipeline->smem = smem;
}
size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline) {
return pipeline->smem;
}
int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline) {
return pipeline->obj.maxTotalThreadsPerThreadgroup;
} }
struct ggml_metal_library { struct ggml_metal_library {
@ -389,28 +345,42 @@ void ggml_metal_library_free(ggml_metal_library_t lib) {
free(lib); free(lib);
} }
ggml_metal_pipeline_t ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) { struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {
[lib->lock lock]; [lib->lock lock];
ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name); struct ggml_metal_pipeline_with_params res = {
/*.pipeline =*/ nil,
/*.nr0 =*/ 0,
/*.nr1 =*/ 0,
/*.nsg =*/ 0,
/*.smem =*/ 0,
};
res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
[lib->lock unlock]; [lib->lock unlock];
return res; return res;
} }
ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) { struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
struct ggml_metal_pipeline_with_params res = {
/*.pipeline =*/ nil,
/*.nr0 =*/ 0,
/*.nr1 =*/ 0,
/*.nsg =*/ 0,
/*.smem =*/ 0,
};
[lib->lock lock]; [lib->lock lock];
ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name); res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
if (res) { if (res.pipeline) {
[lib->lock unlock]; [lib->lock unlock];
return res; return res;
} }
res = ggml_metal_pipeline_init();
@autoreleasepool { @autoreleasepool {
NSError * error = nil; NSError * error = nil;
@ -432,26 +402,43 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]); GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]);
} }
return nil; return res;
} }
res->obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error]; id<MTLComputePipelineState> obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error];
[mtl_function release]; [mtl_function release];
GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) res->obj, if (!obj) {
(int) res->obj.maxTotalThreadsPerThreadgroup, [lib->lock unlock];
(int) res->obj.threadExecutionWidth);
GGML_LOG_ERROR("%s: failed to create pipeline state: base = '%s', name = '%s'\n", __func__, base, name);
if (error) {
GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]);
}
return res;
}
GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name,
(void *) obj,
(int) obj.maxTotalThreadsPerThreadgroup,
(int) obj.threadExecutionWidth);
if (obj.maxTotalThreadsPerThreadgroup == 0 || obj.threadExecutionWidth == 0) {
[obj release];
if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) {
[lib->lock unlock]; [lib->lock unlock];
GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name); GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name);
return nil; return res;
} }
ggml_metal_pipelines_add(lib->pipelines, name, res); res.pipeline = ggml_metal_pipeline_init();
res.pipeline->obj = obj;
ggml_metal_pipelines_add(lib->pipelines, name, res.pipeline);
} }
[lib->lock unlock]; [lib->lock unlock];
@ -496,8 +483,8 @@ void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder) {
[encoder->obj popDebugGroup]; [encoder->obj popDebugGroup];
} }
void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline) { void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline) {
[encoder->obj setComputePipelineState:pipeline->obj]; [encoder->obj setComputePipelineState:pipeline.pipeline->obj];
} }
void ggml_metal_encoder_set_bytes(ggml_metal_encoder_t encoder, void * data, size_t size, int idx) { void ggml_metal_encoder_set_bytes(ggml_metal_encoder_t encoder, void * data, size_t size, int idx) {
@ -622,8 +609,8 @@ ggml_metal_device_t ggml_metal_device_init(void) {
GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__); GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
dev->props.has_tensor = false; dev->props.has_tensor = false;
} else { } else {
ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
if (!ppl) { if (!ppl.pipeline) {
GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__); GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
dev->props.has_tensor = false; dev->props.has_tensor = false;
} }
@ -672,8 +659,8 @@ ggml_metal_device_t ggml_metal_device_init(void) {
GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__); GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
dev->props.has_bfloat = false; dev->props.has_bfloat = false;
} else { } else {
ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
if (!ppl) { if (!ppl.pipeline) {
GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__); GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
dev->props.has_bfloat = false; dev->props.has_bfloat = false;
} }

View File

@ -524,7 +524,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
/*.dim =*/ dim, /*.dim =*/ dim,
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT); auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@ -550,7 +550,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type); auto pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
ggml_metal_kargs_repeat args = { ggml_metal_kargs_repeat args = {
/*.ne00 =*/ ne00, /*.ne00 =*/ ne00,
@ -616,7 +616,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
// TODO: make a simpler cpy_bytes kernel // TODO: make a simpler cpy_bytes kernel
//const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj; //const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
ggml_metal_kargs_cpy args = { ggml_metal_kargs_cpy args = {
/*.nk0 =*/ ne00, /*.nk0 =*/ ne00,
@ -679,7 +679,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
/*.o1 =*/ { 0 }, /*.o1 =*/ { 0 },
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false); auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@ -721,7 +721,7 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
n /= 4; n /= 4;
} }
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@ -760,7 +760,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
n /= 4; n /= 4;
} }
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@ -789,7 +789,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
n /= 4; n /= 4;
} }
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
@ -817,7 +817,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1])); GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
} }
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_glu(lib, op); auto pipeline = ggml_metal_library_get_pipeline_glu(lib, op);
const int32_t swp = ggml_get_op_params_i32(op, 1); const int32_t swp = ggml_get_op_params_i32(op, 1);
const float alpha = ggml_get_op_params_f32(op, 2); const float alpha = ggml_get_op_params_f32(op, 2);
@ -870,7 +870,7 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
/*.np =*/ n, /*.np =*/ n,
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op); auto pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
int nth = 32; // SIMD width int nth = 32; // SIMD width
@ -925,7 +925,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
/*.nb3 =*/ nb3, /*.nb3 =*/ nb3,
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op); auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
int nth = 32; // SIMD width int nth = 32; // SIMD width
@ -936,7 +936,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, ne00); nth = std::min(nth, ne00);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline); const size_t smem = pipeline.smem;
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@ -963,7 +963,7 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op); auto pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
int nth = 1; int nth = 1;
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) { while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
@ -1060,7 +1060,7 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
ggml_metal_op_concurrency_reset(ctx); ggml_metal_op_concurrency_reset(ctx);
{ {
ggml_metal_pipeline_t pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op); auto pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
ggml_metal_kargs_cumsum_add args = { ggml_metal_kargs_cumsum_add args = {
/*.ne00 =*/ ne00, /*.ne00 =*/ ne00,
@ -1106,7 +1106,7 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type); auto pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
ggml_metal_kargs_get_rows args = { ggml_metal_kargs_get_rows args = {
/*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00, /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
@ -1151,7 +1151,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type); auto pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
const int32_t nk0 = ne0/ggml_blck_size(op->type); const int32_t nk0 = ne0/ggml_blck_size(op->type);
@ -1252,7 +1252,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
/*.n_head_log2 =*/ n_head_log2, /*.n_head_log2 =*/ n_head_log2,
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op); auto pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op);
int nth = 32; // SIMD width int nth = 32; // SIMD width
@ -1266,7 +1266,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
} }
} }
const size_t smem = ggml_metal_pipeline_get_smem(pipeline); const size_t smem = pipeline.smem;
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
@ -1322,7 +1322,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
/*.nb2 =*/ nb2, /*.nb2 =*/ nb2,
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op); auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
@ -1409,11 +1409,11 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
/*.nb0 =*/ nb0, /*.nb0 =*/ nb0,
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op); auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
const size_t sms = ggml_metal_pipeline_get_smem(pipeline); const size_t smem = pipeline.smem;
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@ -1426,7 +1426,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 8); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 8);
ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1); ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
@ -1449,7 +1449,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
const int64_t C = op->ne[0]; const int64_t C = op->ne[0];
const int64_t H = op->src[0]->ne[1]; const int64_t H = op->src[0]->ne[1];
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op); auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
int ida = 0; int ida = 0;
@ -1485,7 +1485,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0); GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
@ -1592,7 +1592,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
/* .np = */ np /* .np = */ np
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool); auto pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np); const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
const int ntg = (np + nth - 1) / nth; const int ntg = (np + nth - 1) / nth;
@ -1701,7 +1701,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
GGML_ABORT("unsupported ne11"); GGML_ABORT("unsupported ne11");
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg); auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
ggml_metal_kargs_mul_mv_ext args = { ggml_metal_kargs_mul_mv_ext args = {
/*.ne00 =*/ ne00, /*.ne00 =*/ ne00,
@ -1748,7 +1748,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
// default: break; // default: break;
//} //}
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op); auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
ggml_metal_kargs_mul_mm args = { ggml_metal_kargs_mul_mm args = {
/*.ne00 =*/ ne00, /*.ne00 =*/ ne00,
@ -1773,18 +1773,18 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline); const size_t smem = pipeline.smem;
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1); ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);
} else { } else {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op); auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); const int nr0 = pipeline.nr0;
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); const int nr1 = pipeline.nr1;
const int nsg = ggml_metal_pipeline_get_nsg(pipeline); const int nsg = pipeline.nsg;
const size_t smem = ggml_metal_pipeline_get_smem(pipeline); const size_t smem = pipeline.smem;
ggml_metal_kargs_mul_mv args = { ggml_metal_kargs_mul_mv args = {
/*.ne00 =*/ ne00, /*.ne00 =*/ ne00,
@ -1915,9 +1915,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
nb21, nb21,
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20); auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline); const size_t smem = pipeline.smem;
GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
@ -1938,7 +1938,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
ggml_metal_op_concurrency_reset(ctx); ggml_metal_op_concurrency_reset(ctx);
{ {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op); auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
ggml_metal_kargs_mul_mm_id args = { ggml_metal_kargs_mul_mm_id args = {
/*.ne00 =*/ ne00, /*.ne00 =*/ ne00,
@ -1967,20 +1967,20 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, bid_ids, 4); ggml_metal_encoder_set_buffer (enc, bid_ids, 4);
ggml_metal_encoder_set_buffer (enc, bid_dst, 5); ggml_metal_encoder_set_buffer (enc, bid_dst, 5);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline); const size_t smem = pipeline.smem;
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1); ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
} }
} else { } else {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op); auto pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); const int nr0 = pipeline.nr0;
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); const int nr1 = pipeline.nr1;
const int nsg = ggml_metal_pipeline_get_nsg(pipeline); const int nsg = pipeline.nsg;
const size_t smem = ggml_metal_pipeline_get_smem(pipeline); const size_t smem = pipeline.smem;
ggml_metal_kargs_mul_mv_id args = { ggml_metal_kargs_mul_mv_id args = {
/*.nei0 =*/ ne20, /*.nei0 =*/ ne20,
@ -2064,7 +2064,7 @@ int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) {
/*.nb21 =*/ nb21, /*.nb21 =*/ nb21,
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID); auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID);
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@ -2308,7 +2308,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.nb33 =*/nb33, /*.nb33 =*/nb33,
}; };
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_pipeline(enc, pipeline0);
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
@ -2339,7 +2339,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.nb33 =*/ nb33, /*.nb33 =*/ nb33,
}; };
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg); auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_pipeline(enc, pipeline0);
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
@ -2424,7 +2424,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.logit_softcap =*/ logit_softcap, /*.logit_softcap =*/ logit_softcap,
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg); auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@ -2476,7 +2476,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.nb33 =*/nb33, /*.nb33 =*/nb33,
}; };
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_pipeline(enc, pipeline0);
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
@ -2578,7 +2578,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.logit_softcap =*/ logit_softcap, /*.logit_softcap =*/ logit_softcap,
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg); auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
@ -2630,7 +2630,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
nrows, nrows,
}; };
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg); auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_pipeline(enc, pipeline0);
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
@ -2762,7 +2762,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
// the offsets of src1 and all fused buffers are relative to the start of the src1 buffer // the offsets of src1 and all fused buffers are relative to the start of the src1 buffer
bid_src1.offs = 0; bid_src1.offs = 0;
ggml_metal_pipeline_t pipeline = nullptr; struct ggml_metal_pipeline_with_params pipeline;
if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) { if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
GGML_ASSERT(ggml_is_contiguous(op->src[0])); GGML_ASSERT(ggml_is_contiguous(op->src[0]));
@ -2835,7 +2835,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
/*.eps =*/ eps, /*.eps =*/ eps,
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op); auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2; nth *= 2;
@ -2844,7 +2844,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, ne00/4); nth = std::min(nth, ne00/4);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline); const size_t smem = pipeline.smem;
const int64_t nrows = ggml_nrows(op->src[0]); const int64_t nrows = ggml_nrows(op->src[0]);
@ -2887,7 +2887,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
/*.eps =*/ eps, /*.eps =*/ eps,
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op); auto pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op);
int nth = 32; // SIMD width int nth = 32; // SIMD width
//while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { //while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
@ -2897,7 +2897,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
//nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); //nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
//nth = std::min(nth, ne00/4); //nth = std::min(nth, ne00/4);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline); const size_t smem = pipeline.smem;
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@ -3022,7 +3022,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
} }
} }
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse); auto pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
int nth = 32; // SIMD width int nth = 32; // SIMD width
@ -3033,7 +3033,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, args.ne00_t); nth = std::min(nth, args.ne00_t);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline); const size_t smem = pipeline.smem;
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@ -3127,7 +3127,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
/* src2 =*/ op->src[2] != nullptr, /* src2 =*/ op->src[2] != nullptr,
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op); auto pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@ -3199,7 +3199,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
/*.KHW =*/ KH * KW, /*.KHW =*/ KH * KW,
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_im2col(lib, op); auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
@ -3270,7 +3270,7 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
/*.d1 =*/ d1, /*.d1 =*/ d1,
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op); auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);
int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline); int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
nth = std::min(nth, 256); nth = std::min(nth, 256);
@ -3325,7 +3325,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
/*.nb1 =*/ nb1, /*.nb1 =*/ nb1,
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op); auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@ -3377,7 +3377,7 @@ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
/*.nb2 =*/ nb2, /*.nb2 =*/ nb2,
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op); auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@ -3433,7 +3433,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
/*.sf3 =*/ sf3 /*.sf3 =*/ sf3
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_upscale(lib, op); auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
@ -3477,7 +3477,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
/*.nb3 =*/ nb3 /*.nb3 =*/ nb3
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad(lib, op); auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
const int nth = std::min(1024, ne0); const int nth = std::min(1024, ne0);
@ -3523,7 +3523,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
/*.p1 =*/ ((const int32_t *)(op->op_params))[1] /*.p1 =*/ ((const int32_t *)(op->op_params))[1]
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op); auto pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
const int nth = std::min(1024, ne0); const int nth = std::min(1024, ne0);
@ -3560,7 +3560,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
const int nth = std::min(1024, ne0); const int nth = std::min(1024, ne0);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_arange(lib, op); auto pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@ -3591,7 +3591,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
/*.max_period =*/ max_period, /*.max_period =*/ max_period,
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op); auto pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
const int nth = std::max(1, std::min(1024, dim/2)); const int nth = std::max(1, std::min(1024, dim/2));
@ -3621,7 +3621,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
/*.nb01 = */ nb01, /*.nb01 = */ nb01,
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argmax(lib, op); auto pipeline = ggml_metal_library_get_pipeline_argmax(lib, op);
const int64_t nrows = ggml_nrows(op->src[0]); const int64_t nrows = ggml_nrows(op->src[0]);
@ -3630,7 +3630,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
nth *= 2; nth *= 2;
} }
const size_t smem = ggml_metal_pipeline_get_smem(pipeline); const size_t smem = pipeline.smem;
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@ -3657,7 +3657,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op); auto pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
// bitonic sort requires the number of elements to be power of 2 // bitonic sort requires the number of elements to be power of 2
int nth = 1; int nth = 1;
@ -3706,7 +3706,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1); ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op); auto pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
int len = nth; int len = nth;
@ -3764,7 +3764,7 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op); auto pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
// bitonic sort requires the number of elements to be power of 2 // bitonic sort requires the number of elements to be power of 2
int nth = 1; int nth = 1;
@ -3818,7 +3818,7 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1); ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op); auto pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
int len = args.top_k; int len = args.top_k;
@ -3881,7 +3881,7 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
/*.slope =*/ slope /*.slope =*/ slope
}; };
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
int64_t n = ggml_nelements(op); int64_t n = ggml_nelements(op);
@ -3910,7 +3910,7 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op); auto pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
const int64_t np = ggml_nelements(op->src[0]); const int64_t np = ggml_nelements(op->src[0]);
ggml_metal_kargs_opt_step_adamw args = { ggml_metal_kargs_opt_step_adamw args = {
@ -3946,7 +3946,7 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op); auto pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
const int64_t np = ggml_nelements(op->src[0]); const int64_t np = ggml_nelements(op->src[0]);
ggml_metal_kargs_opt_step_sgd args = { ggml_metal_kargs_opt_step_sgd args = {