metal : add opt_step_adamw and op_sum (#16529)
* scaffold to support opt step adamw on metal (not written so far) * add opt-step-adamw kernel for metal * pass op->src[4] as a separate buffer to the pipeline * add bounds check to opt-step-adamw kernel * complete scaffold for GGML_OP_SUM * naive GGML_OP_SUM kernel * remove unwanted comment * change OP_SUM capability gate * Add has_simdgroup_reduction to both ops to pass CI
This commit is contained in:
parent
81d54bbfd5
commit
a31cf36ad9
|
|
@ -268,6 +268,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||||
|
assert(op->op == GGML_OP_SUM);
|
||||||
|
|
||||||
|
char base[256];
|
||||||
|
char name[256];
|
||||||
|
|
||||||
|
snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type));
|
||||||
|
snprintf(name, 256, "%s", base);
|
||||||
|
|
||||||
|
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||||
|
if (res) {
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||||
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
|
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
|
||||||
|
|
||||||
|
|
@ -1482,3 +1501,21 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||||
|
assert(op->op == GGML_OP_OPT_STEP_ADAMW);
|
||||||
|
|
||||||
|
char base[256];
|
||||||
|
char name[256];
|
||||||
|
|
||||||
|
snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
|
||||||
|
snprintf(name, 256, "%s", base);
|
||||||
|
|
||||||
|
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||||
|
if (res) {
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -109,6 +109,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_me
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
|
|
@ -134,6 +135,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_me
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
|
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
||||||
ggml_metal_library_t lib,
|
ggml_metal_library_t lib,
|
||||||
|
|
|
||||||
|
|
@ -656,6 +656,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||||
case GGML_OP_COS:
|
case GGML_OP_COS:
|
||||||
case GGML_OP_LOG:
|
case GGML_OP_LOG:
|
||||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||||
|
case GGML_OP_SUM:
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
case GGML_OP_MEAN:
|
case GGML_OP_MEAN:
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
|
|
@ -798,6 +799,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
case GGML_OP_OPT_STEP_ADAMW:
|
||||||
|
return has_simdgroup_reduction;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -544,6 +544,10 @@ typedef struct{
|
||||||
float limit;
|
float limit;
|
||||||
} ggml_metal_kargs_glu;
|
} ggml_metal_kargs_glu;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
uint64_t np;
|
||||||
|
} ggml_metal_kargs_sum;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int64_t ne00;
|
int64_t ne00;
|
||||||
int64_t ne01;
|
int64_t ne01;
|
||||||
|
|
@ -773,4 +777,8 @@ typedef struct {
|
||||||
uint64_t nb01;
|
uint64_t nb01;
|
||||||
} ggml_metal_kargs_argmax;
|
} ggml_metal_kargs_argmax;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
int64_t np;
|
||||||
|
} ggml_metal_kargs_opt_step_adamw;
|
||||||
|
|
||||||
#endif // GGML_METAL_IMPL
|
#endif // GGML_METAL_IMPL
|
||||||
|
|
|
||||||
|
|
@ -301,6 +301,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||||
{
|
{
|
||||||
n_fuse = ggml_metal_op_glu(ctx, idx);
|
n_fuse = ggml_metal_op_glu(ctx, idx);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_SUM:
|
||||||
|
{
|
||||||
|
n_fuse = ggml_metal_op_sum(ctx, idx);
|
||||||
|
} break;
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
case GGML_OP_MEAN:
|
case GGML_OP_MEAN:
|
||||||
{
|
{
|
||||||
|
|
@ -410,6 +414,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||||
{
|
{
|
||||||
n_fuse = ggml_metal_op_argmax(ctx, idx);
|
n_fuse = ggml_metal_op_argmax(ctx, idx);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_OPT_STEP_ADAMW:
|
||||||
|
{
|
||||||
|
n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
|
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
|
||||||
|
|
@ -840,6 +848,30 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
|
||||||
|
ggml_tensor * op = ctx->node(idx);
|
||||||
|
|
||||||
|
ggml_metal_library_t lib = ctx->lib;
|
||||||
|
ggml_metal_encoder_t enc = ctx->enc;
|
||||||
|
|
||||||
|
const uint64_t n = (uint64_t) ggml_nelements(op->src[0]);
|
||||||
|
|
||||||
|
ggml_metal_kargs_sum args = {
|
||||||
|
/*.np =*/ n,
|
||||||
|
};
|
||||||
|
|
||||||
|
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
|
||||||
|
|
||||||
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||||
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||||
|
|
||||||
|
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
|
||||||
|
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||||
ggml_tensor * op = ctx->node(idx);
|
ggml_tensor * op = ctx->node(idx);
|
||||||
|
|
||||||
|
|
@ -3401,3 +3433,39 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
||||||
|
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
|
||||||
|
ggml_tensor * op = ctx->node(idx);
|
||||||
|
|
||||||
|
ggml_metal_library_t lib = ctx->lib;
|
||||||
|
ggml_metal_encoder_t enc = ctx->enc;
|
||||||
|
|
||||||
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
|
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||||
|
|
||||||
|
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
|
||||||
|
|
||||||
|
const int64_t np = ggml_nelements(op->src[0]);
|
||||||
|
ggml_metal_kargs_opt_step_adamw args = {
|
||||||
|
/*.np =*/ np,
|
||||||
|
};
|
||||||
|
|
||||||
|
int ida = 0;
|
||||||
|
|
||||||
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||||
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
|
||||||
|
|
||||||
|
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
|
||||||
|
const int64_t n = (np + nth - 1) / nth;
|
||||||
|
|
||||||
|
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
|
||||||
|
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,7 @@ int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
|
||||||
|
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
|
||||||
|
|
@ -78,6 +79,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
|
||||||
|
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1723,6 +1723,24 @@ kernel void kernel_geglu_quick_f32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_op_sum_f32(
|
||||||
|
constant ggml_metal_kargs_sum & args,
|
||||||
|
device const float * src0,
|
||||||
|
device float * dst,
|
||||||
|
ushort tiitg[[thread_index_in_threadgroup]]) {
|
||||||
|
|
||||||
|
if (tiitg != 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
float acc = 0.0f;
|
||||||
|
for (ulong i = 0; i < args.np; ++i) {
|
||||||
|
acc += src0[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
dst[0] = acc;
|
||||||
|
}
|
||||||
|
|
||||||
template <bool norm>
|
template <bool norm>
|
||||||
kernel void kernel_sum_rows(
|
kernel void kernel_sum_rows(
|
||||||
constant ggml_metal_kargs_sum_rows & args,
|
constant ggml_metal_kargs_sum_rows & args,
|
||||||
|
|
@ -8754,3 +8772,37 @@ kernel void kernel_pool_2d_avg_f32(
|
||||||
|
|
||||||
o_ptr[cur_oh * args.OW + cur_ow] = res;
|
o_ptr[cur_oh * args.OW + cur_ow] = res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_opt_step_adamw_f32(
|
||||||
|
constant ggml_metal_kargs_opt_step_adamw & args,
|
||||||
|
device float * x,
|
||||||
|
device const float * g,
|
||||||
|
device float * g_m,
|
||||||
|
device float * g_v,
|
||||||
|
device const float * pars,
|
||||||
|
uint gid[[thread_position_in_grid]]) {
|
||||||
|
|
||||||
|
if (gid >= args.np) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float alpha = pars[0];
|
||||||
|
const float beta1 = pars[1];
|
||||||
|
const float beta2 = pars[2];
|
||||||
|
const float eps = pars[3];
|
||||||
|
const float wd = pars[4];
|
||||||
|
const float beta1h = pars[5];
|
||||||
|
const float beta2h = pars[6];
|
||||||
|
|
||||||
|
const float gi = g[gid];
|
||||||
|
const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1);
|
||||||
|
const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);
|
||||||
|
|
||||||
|
g_m[gid] = gmi;
|
||||||
|
g_v[gid] = gvi;
|
||||||
|
|
||||||
|
const float mh = gmi * beta1h;
|
||||||
|
const float vh = sqrt(gvi * beta2h) + eps;
|
||||||
|
|
||||||
|
x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue