metal : support GGML_OP_SET (#19548)
parent 3bb78133ab
commit 490eb96b88
@@ -1159,6 +1159,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_MUL_MAT:
        case GGML_OP_MUL_MAT_ID:
            return has_simdgroup_reduction;
+        case GGML_OP_SET:
        case GGML_OP_CPY:
        case GGML_OP_DUP:
        case GGML_OP_CONT:
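For context: GGML_OP_SET is the node produced by ggml_set/ggml_set_inplace, which write src into a strided view of dst at a byte offset. A minimal sketch of a graph that would reach the new Metal path, assuming only the public ggml API (the shapes and the helper name build_set_example are illustrative, not from this commit):

    #include "ggml.h"

    // write src contiguously into dst, starting at a byte offset into dst's data
    static ggml_tensor * build_set_example(ggml_context * ctx) {
        ggml_tensor * dst = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 12, 10);
        ggml_tensor * src = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  6,  5);

        // nb1/nb2/nb3 describe the destination view; reusing src's strides keeps the
        // written region contiguous (the backward pass requires this, per the tests below)
        return ggml_set(ctx, dst, src, src->nb[1], src->nb[2], src->nb[3], /*offset =*/ 2*dst->nb[1]);
    }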
@@ -426,6 +426,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
            } break;
+        case GGML_OP_SET:
+            {
+                n_fuse = ggml_metal_op_set(ctx, idx);
+            } break;
        case GGML_OP_DUP:
        case GGML_OP_CPY:
        case GGML_OP_CONT:
@@ -1609,6 +1613,134 @@ int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
+    const size_t pnb1 = ((const int32_t *) op->op_params)[0];
+    const size_t pnb2 = ((const int32_t *) op->op_params)[1];
+    const size_t pnb3 = ((const int32_t *) op->op_params)[2];
+    const size_t offs = ((const int32_t *) op->op_params)[3];
+
+    const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
+
+    if (!inplace) {
+        // run a separate kernel to cpy src->dst
+        // not sure how to avoid this
+        // TODO: make a simpler cpy_bytes kernel
+
+        //const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
+        auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
+
+        ggml_metal_kargs_cpy args = {
+            /*.nk0  =*/ ne00,
+            /*.ne00 =*/ ne00,
+            /*.ne01 =*/ ne01,
+            /*.ne02 =*/ ne02,
+            /*.ne03 =*/ ne03,
+            /*.nb00 =*/ nb00,
+            /*.nb01 =*/ nb01,
+            /*.nb02 =*/ nb02,
+            /*.nb03 =*/ nb03,
+            /*.ne0  =*/ ne0,
+            /*.ne1  =*/ ne1,
+            /*.ne2  =*/ ne2,
+            /*.ne3  =*/ ne3,
+            /*.nb0  =*/ nb0,
+            /*.nb1  =*/ nb1,
+            /*.nb2  =*/ nb2,
+            /*.nb3  =*/ nb3,
+        };
+
+        ggml_metal_encoder_set_pipeline(enc, pipeline);
+        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+        ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
+
+        const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
+
+        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+        ggml_metal_op_concurrency_reset(ctx);
+    }
+
+    auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);
+
+    GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);
+
+    int64_t nk0 = ne10;
+    if (ggml_is_quantized(op->src[1]->type)) {
+        nk0 = ne10/16;
+    } else if (ggml_is_quantized(op->type)) {
+        nk0 = ne10/ggml_blck_size(op->type);
+    }
+
+    int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+    // when rows are small, we can batch them together in a single threadgroup
+    int nrptg = 1;
+
+    // TODO: relax this constraint in the future
+    if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {
+        if (nth > nk0) {
+            nrptg = (nth + nk0 - 1)/nk0;
+            nth   = nk0;
+
+            if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+                nrptg--;
+            }
+        }
+    }
+
+    nth = std::min<int>(nth, nk0);
+
+    ggml_metal_kargs_cpy args = {
+        /*.nk0  =*/ nk0,
+        /*.ne00 =*/ ne10,
+        /*.ne01 =*/ ne11,
+        /*.ne02 =*/ ne12,
+        /*.ne03 =*/ ne13,
+        /*.nb00 =*/ nb10,
+        /*.nb01 =*/ nb11,
+        /*.nb02 =*/ nb12,
+        /*.nb03 =*/ nb13,
+        /*.ne0  =*/ ne10,
+        /*.ne1  =*/ ne11,
+        /*.ne2  =*/ ne12,
+        /*.ne3  =*/ ne13,
+        /*.nb0  =*/ ggml_element_size(op),
+        /*.nb1  =*/ pnb1,
+        /*.nb2  =*/ pnb2,
+        /*.nb3  =*/ pnb3,
+    };
+
+    const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
+
+    bid_dst.offs += offs;
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, bid_src1, 1);
+    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
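The five op_params slots decoded at the top of ggml_metal_op_set are packed on the graph-building side by ggml_set_impl. A hedged paraphrase of that producer (pack_set_params is a hypothetical name, and ggml_set_op_params is ggml's internal helper, so verify against ggml.c):

    // order must match the decode above: nb1, nb2, nb3, offset, inplace
    static void pack_set_params(ggml_tensor * result,
            size_t nb1, size_t nb2, size_t nb3, size_t offset, bool inplace) {
        const int32_t params[] = {
            (int32_t) nb1,    // [0] -> pnb1 : dst view stride, dim 1
            (int32_t) nb2,    // [1] -> pnb2 : dst view stride, dim 2
            (int32_t) nb3,    // [2] -> pnb3 : dst view stride, dim 3
            (int32_t) offset, // [3] -> offs : byte offset into dst
            inplace ? 1 : 0,  // [4] -> inplace : skip the initial src0 -> dst copy pass
        };
        ggml_set_op_params(result, params, sizeof(params));
    }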
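Semantically, the two dispatches in ggml_metal_op_set implement the following reference loop: pass 1 (only when !inplace) copies src0 into dst byte for byte, then pass 2 writes src1 row by row into the strided region described by op_params. A hedged sketch for the F32, contiguous-src1 case (set_ref_f32 is a hypothetical name, not part of the commit):

    #include <cstdint>
    #include <cstring>

    static void set_ref_f32(char * dst, const float * src1,
            int64_t ne10, int64_t ne11, int64_t ne12, int64_t ne13,
            size_t pnb1, size_t pnb2, size_t pnb3, size_t offs) {
        for (int64_t i3 = 0; i3 < ne13; ++i3) {
            for (int64_t i2 = 0; i2 < ne12; ++i2) {
                for (int64_t i1 = 0; i1 < ne11; ++i1) {
                    // destination row inside the strided view of dst
                    char * d = dst + offs + i1*pnb1 + i2*pnb2 + i3*pnb3;
                    // matching contiguous row of src1
                    const float * s = src1 + ((i3*ne12 + i2)*ne11 + i1)*ne10;
                    memcpy(d, s, ne10*sizeof(float));
                }
            }
        }
    }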
@@ -59,6 +59,7 @@ int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rwkv     (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_set      (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_cpy      (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_pool_1d  (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_pool_2d  (ggml_metal_op_t ctx, int idx);
@@ -2786,9 +2786,10 @@ struct test_set : public test_case {
     const ggml_type type_dst;
     const std::array<int64_t, 4> ne;
     const int dim;
+    const bool inplace;
 
     std::string vars() override {
-        return VARS_TO_STR4(type_src, type_dst, ne, dim);
+        return VARS_TO_STR5(type_src, type_dst, ne, dim, inplace);
     }
 
     size_t op_size(ggml_tensor * t) override {
@@ -2796,8 +2797,8 @@ struct test_set : public test_case {
     }
 
     test_set(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {6, 5, 4, 3}, int dim = 1)
-        : type_src(type_src), type_dst(type_dst), ne(ne), dim(dim) {}
+            std::array<int64_t, 4> ne = {6, 5, 4, 3}, int dim = 1, bool inplace = false)
+        : type_src(type_src), type_dst(type_dst), ne(ne), dim(dim), inplace(inplace) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
@@ -2808,7 +2809,7 @@ struct test_set : public test_case {
         for (int i = 0; i < dim; ++i) {
             ne_dst[i] *= 2;
         }
-        ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data());
+        ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data());
         ggml_set_param(dst);
         ggml_set_name(dst, "dst");
 
@@ -2816,9 +2817,16 @@ struct test_set : public test_case {
         for (int i = 0; i < dim; ++i) {
             offset += ((ne_dst[i] - ne[i])/2)*dst->nb[i];
         }
-        ggml_tensor * out = ggml_set(ctx, dst, src,
-            // The backward pass requires setting a contiguous region:
-            src->nb[1], src->nb[2], src->nb[3], offset);
+        ggml_tensor * out;
+        if (inplace) {
+            out = ggml_set_inplace(ctx, dst, src,
+                // The backward pass requires setting a contiguous region:
+                src->nb[1], src->nb[2], src->nb[3], offset);
+        } else {
+            out = ggml_set(ctx, dst, src,
+                // The backward pass requires setting a contiguous region:
+                src->nb[1], src->nb[2], src->nb[3], offset);
+        }
         ggml_set_name(out, "out");
 
         return out;
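The two branches differ only in which tensor ggml builds the SET node on. Paraphrasing ggml_set_impl (hedged sketch; set_result is a hypothetical name):

    #include "ggml.h"

    // inplace: the result aliases dst's data (a view), so the backend can skip the
    // initial copy pass; otherwise a fresh tensor is created and src0 is copied first
    static ggml_tensor * set_result(ggml_context * ctx, ggml_tensor * dst, bool inplace) {
        return inplace ? ggml_view_tensor(ctx, dst) : ggml_dup_tensor(ctx, dst);
    }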
@@ -7428,11 +7436,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {1, 2, 0, 3}));
 
     for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
-        test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim));
+        test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim, false));
+        test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim, true));
     }
 
     for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
-        test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim));
+        test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim, false));
+        test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim, true));
     }
 
     // same-type copy
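These cases can be exercised in isolation with test-backend-ops' op filter, assuming its usual CLI flags (-o filters by op name, -b selects a backend):

    test-backend-ops test -o SET -b Metal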