metal : fuse NORM + MUL + ADD, support non-multiples of 4 (#16220)
* metal : fuse NORM + MUL + ADD * metal : support norms of non-multiple of 4 * cont : fix comment [no ci]
This commit is contained in:
parent
4ea00794b8
commit
dfcd53f7ec
|
|
@ -383,6 +383,7 @@ void ggml_graph_optimize(ggml_cgraph * gf) {
|
||||||
// fuse only ops that start with these operations
|
// fuse only ops that start with these operations
|
||||||
// can be expanded when needed
|
// can be expanded when needed
|
||||||
if (node.op() == GGML_OP_ADD ||
|
if (node.op() == GGML_OP_ADD ||
|
||||||
|
node.op() == GGML_OP_NORM ||
|
||||||
node.op() == GGML_OP_RMS_NORM) {
|
node.op() == GGML_OP_RMS_NORM) {
|
||||||
ops[0] = node.op();
|
ops[0] = node.op();
|
||||||
|
|
||||||
|
|
@ -392,6 +393,7 @@ void ggml_graph_optimize(ggml_cgraph * gf) {
|
||||||
// can be expanded when needed
|
// can be expanded when needed
|
||||||
if (gf->nodes[f]->op != GGML_OP_ADD &&
|
if (gf->nodes[f]->op != GGML_OP_ADD &&
|
||||||
gf->nodes[f]->op != GGML_OP_MUL &&
|
gf->nodes[f]->op != GGML_OP_MUL &&
|
||||||
|
gf->nodes[f]->op != GGML_OP_NORM &&
|
||||||
gf->nodes[f]->op != GGML_OP_RMS_NORM) {
|
gf->nodes[f]->op != GGML_OP_RMS_NORM) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1090,36 +1090,6 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin(
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
|
|
||||||
assert(op->op == GGML_OP_RMS_NORM);
|
|
||||||
|
|
||||||
GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
|
|
||||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
||||||
|
|
||||||
char base[256];
|
|
||||||
char name[256];
|
|
||||||
|
|
||||||
switch (n_fuse) {
|
|
||||||
case 1: snprintf(base, 256, "kernel_rms_norm_f32"); break;
|
|
||||||
case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32"); break;
|
|
||||||
case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32"); break;
|
|
||||||
default: GGML_ABORT("fatal error");
|
|
||||||
}
|
|
||||||
|
|
||||||
snprintf(name, 256, "%s", base);
|
|
||||||
|
|
||||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
|
||||||
if (res) {
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
||||||
|
|
||||||
ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
|
|
||||||
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||||
assert(op->op == GGML_OP_L2_NORM);
|
assert(op->op == GGML_OP_L2_NORM);
|
||||||
|
|
||||||
|
|
@ -1167,16 +1137,37 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_libr
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) {
|
||||||
assert(op->op == GGML_OP_NORM);
|
assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM);
|
||||||
|
|
||||||
GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
|
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||||
GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
|
|
||||||
|
|
||||||
char base[256];
|
char base[256];
|
||||||
char name[256];
|
char name[256];
|
||||||
|
|
||||||
snprintf(base, 256, "kernel_norm_f32");
|
const char * suffix = "";
|
||||||
|
if (op->ne[0] % 4 == 0) {
|
||||||
|
suffix = "_4";
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (op->op) {
|
||||||
|
case GGML_OP_NORM:
|
||||||
|
switch (n_fuse) {
|
||||||
|
case 1: snprintf(base, 256, "kernel_norm_f32%s", suffix); break;
|
||||||
|
case 2: snprintf(base, 256, "kernel_norm_mul_f32%s", suffix); break;
|
||||||
|
case 3: snprintf(base, 256, "kernel_norm_mul_add_f32%s", suffix); break;
|
||||||
|
default: GGML_ABORT("fatal error");
|
||||||
|
} break;
|
||||||
|
case GGML_OP_RMS_NORM:
|
||||||
|
switch (n_fuse) {
|
||||||
|
case 1: snprintf(base, 256, "kernel_rms_norm_f32%s", suffix); break;
|
||||||
|
case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32%s", suffix); break;
|
||||||
|
case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32%s", suffix); break;
|
||||||
|
default: GGML_ABORT("fatal error");
|
||||||
|
} break;
|
||||||
|
default: GGML_ABORT("fatal error");
|
||||||
|
}
|
||||||
|
|
||||||
snprintf(name, 256, "%s", base);
|
snprintf(name, 256, "%s", base);
|
||||||
|
|
||||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||||
|
|
|
||||||
|
|
@ -123,10 +123,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_me
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
|
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
|
|
|
||||||
|
|
@ -661,13 +661,13 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
case GGML_OP_GROUP_NORM:
|
case GGML_OP_GROUP_NORM:
|
||||||
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
||||||
case GGML_OP_RMS_NORM:
|
|
||||||
case GGML_OP_L2_NORM:
|
case GGML_OP_L2_NORM:
|
||||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
||||||
case GGML_OP_ARGMAX:
|
case GGML_OP_ARGMAX:
|
||||||
return has_simdgroup_reduction;
|
return has_simdgroup_reduction;
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
case GGML_OP_RMS_NORM:
|
||||||
|
return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0]));
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_IM2COL:
|
case GGML_OP_IM2COL:
|
||||||
|
|
|
||||||
|
|
@ -428,16 +428,11 @@ typedef struct {
|
||||||
uint64_t nb1;
|
uint64_t nb1;
|
||||||
} ggml_metal_kargs_mul_mv_id;
|
} ggml_metal_kargs_mul_mv_id;
|
||||||
|
|
||||||
|
// NORM
|
||||||
|
// RMS_NORM
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int32_t ne00;
|
int32_t ne00;
|
||||||
int32_t ne00_4;
|
int32_t ne00_t;
|
||||||
uint64_t nb01;
|
|
||||||
float eps;
|
|
||||||
} ggml_metal_kargs_norm;
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
int32_t ne00;
|
|
||||||
int32_t ne00_4;
|
|
||||||
uint64_t nb1;
|
uint64_t nb1;
|
||||||
uint64_t nb2;
|
uint64_t nb2;
|
||||||
uint64_t nb3;
|
uint64_t nb3;
|
||||||
|
|
@ -448,7 +443,7 @@ typedef struct {
|
||||||
uint64_t nbf1[3];
|
uint64_t nbf1[3];
|
||||||
uint64_t nbf2[3];
|
uint64_t nbf2[3];
|
||||||
uint64_t nbf3[3];
|
uint64_t nbf3[3];
|
||||||
} ggml_metal_kargs_rms_norm;
|
} ggml_metal_kargs_norm;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int32_t ne00;
|
int32_t ne00;
|
||||||
|
|
|
||||||
|
|
@ -266,10 +266,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||||
{
|
{
|
||||||
n_fuse = ggml_metal_op_set_rows(ctx, idx);
|
n_fuse = ggml_metal_op_set_rows(ctx, idx);
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_RMS_NORM:
|
|
||||||
{
|
|
||||||
n_fuse = ggml_metal_op_rms_norm(ctx, idx);
|
|
||||||
} break;
|
|
||||||
case GGML_OP_L2_NORM:
|
case GGML_OP_L2_NORM:
|
||||||
{
|
{
|
||||||
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
|
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
|
||||||
|
|
@ -279,6 +275,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||||
n_fuse = ggml_metal_op_group_norm(ctx, idx);
|
n_fuse = ggml_metal_op_group_norm(ctx, idx);
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
|
case GGML_OP_RMS_NORM:
|
||||||
{
|
{
|
||||||
n_fuse = ggml_metal_op_norm(ctx, idx);
|
n_fuse = ggml_metal_op_norm(ctx, idx);
|
||||||
} break;
|
} break;
|
||||||
|
|
@ -2346,146 +2343,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
||||||
return n_fuse;
|
return n_fuse;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx) {
|
|
||||||
ggml_cgraph * gf = ctx->gf;
|
|
||||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
|
||||||
|
|
||||||
ggml_metal_library_t lib = ctx->lib;
|
|
||||||
ggml_metal_encoder_t enc = ctx->enc;
|
|
||||||
|
|
||||||
const int idx_end = ctx->idx_end;
|
|
||||||
|
|
||||||
const bool use_fusion = ctx->use_fusion;
|
|
||||||
|
|
||||||
const int debug_fusion = ctx->debug_fusion;
|
|
||||||
|
|
||||||
ggml_tensor ** ops = ggml_graph_nodes(gf) + idx;
|
|
||||||
|
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
|
||||||
|
|
||||||
float eps;
|
|
||||||
memcpy(&eps, op->op_params, sizeof(float));
|
|
||||||
|
|
||||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
||||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
||||||
|
|
||||||
ggml_metal_kargs_rms_norm args = {
|
|
||||||
/*.ne00 =*/ ne00,
|
|
||||||
/*.ne00_4 =*/ ne00/4,
|
|
||||||
/*.nb1 =*/ nb1,
|
|
||||||
/*.nb2 =*/ nb2,
|
|
||||||
/*.nb3 =*/ nb3,
|
|
||||||
/*.eps =*/ eps,
|
|
||||||
/*.nef1 =*/ { ne01 },
|
|
||||||
/*.nef2 =*/ { ne02 },
|
|
||||||
/*.nef3 =*/ { ne03 },
|
|
||||||
/*.nbf1 =*/ { nb01 },
|
|
||||||
/*.nbf2 =*/ { nb02 },
|
|
||||||
/*.nbf3 =*/ { nb03 },
|
|
||||||
};
|
|
||||||
|
|
||||||
ggml_op fops[8];
|
|
||||||
|
|
||||||
int n_fuse = 1;
|
|
||||||
|
|
||||||
ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 };
|
|
||||||
|
|
||||||
// d[0] = rms_norm(a)
|
|
||||||
// d[1] = mul(d[0], b)
|
|
||||||
// d[2] = add(d[1], c)
|
|
||||||
if (use_fusion) {
|
|
||||||
fops[0] = GGML_OP_RMS_NORM;
|
|
||||||
fops[1] = GGML_OP_MUL;
|
|
||||||
fops[2] = GGML_OP_ADD;
|
|
||||||
|
|
||||||
for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
|
|
||||||
if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ops[n_fuse + 1]->src[1]->ne[0] != op->ne[0]) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!ggml_is_contiguous_rows(ops[n_fuse + 1]->src[1])) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ops[n_fuse + 1]->type != GGML_TYPE_F32) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
//ctx->fuse_cnt[ops[n_fuse + 1]->op]++;
|
|
||||||
|
|
||||||
bid_fuse[n_fuse] = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]);
|
|
||||||
|
|
||||||
args.nef1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[1];
|
|
||||||
args.nef2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[2];
|
|
||||||
args.nef3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[3];
|
|
||||||
|
|
||||||
args.nbf1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[1];
|
|
||||||
args.nbf2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[2];
|
|
||||||
args.nbf3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[3];
|
|
||||||
}
|
|
||||||
|
|
||||||
++n_fuse;
|
|
||||||
|
|
||||||
if (debug_fusion > 1 && n_fuse > 1) {
|
|
||||||
if (n_fuse == 2) {
|
|
||||||
GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__);
|
|
||||||
}
|
|
||||||
if (n_fuse == 3) {
|
|
||||||
GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (n_fuse > 1) {
|
|
||||||
bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]);
|
|
||||||
|
|
||||||
for (int i = 1; i < n_fuse; ++i) {
|
|
||||||
if (!ggml_metal_op_concurrency_check(ctx, ops[i])) {
|
|
||||||
ggml_metal_op_concurrency_reset(ctx);
|
|
||||||
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rms_norm(lib, op, n_fuse);
|
|
||||||
|
|
||||||
int nth = 32; // SIMD width
|
|
||||||
|
|
||||||
while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
||||||
nth *= 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
||||||
nth = std::min(nth, ne00/4);
|
|
||||||
|
|
||||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
|
||||||
|
|
||||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
||||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
||||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
||||||
ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 2);
|
|
||||||
ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 3);
|
|
||||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
|
|
||||||
|
|
||||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
||||||
|
|
||||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
|
||||||
|
|
||||||
return n_fuse;
|
|
||||||
}
|
|
||||||
|
|
||||||
int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||||
ggml_cgraph * gf = ctx->gf;
|
ggml_cgraph * gf = ctx->gf;
|
||||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||||
|
|
@ -2594,6 +2451,14 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
||||||
ggml_metal_library_t lib = ctx->lib;
|
ggml_metal_library_t lib = ctx->lib;
|
||||||
ggml_metal_encoder_t enc = ctx->enc;
|
ggml_metal_encoder_t enc = ctx->enc;
|
||||||
|
|
||||||
|
const int idx_end = ctx->idx_end;
|
||||||
|
|
||||||
|
const bool use_fusion = ctx->use_fusion;
|
||||||
|
|
||||||
|
const int debug_fusion = ctx->debug_fusion;
|
||||||
|
|
||||||
|
ggml_tensor ** ops = ggml_graph_nodes(gf) + idx;
|
||||||
|
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
|
|
@ -2602,37 +2467,121 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
||||||
float eps;
|
float eps;
|
||||||
memcpy(&eps, op->op_params, sizeof(float));
|
memcpy(&eps, op->op_params, sizeof(float));
|
||||||
|
|
||||||
|
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||||
|
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||||
|
|
||||||
ggml_metal_kargs_norm args = {
|
ggml_metal_kargs_norm args = {
|
||||||
/*.ne00 =*/ ne00,
|
/*.ne00 =*/ ne00,
|
||||||
/*.ne00_4 =*/ ne00/4,
|
/*.ne00_t =*/ ne00 % 4 == 0 ? ne00/4 : ne00,
|
||||||
/*.nb01 =*/ nb01,
|
/*.nb1 =*/ nb1,
|
||||||
|
/*.nb2 =*/ nb2,
|
||||||
|
/*.nb3 =*/ nb3,
|
||||||
/*.eps =*/ eps,
|
/*.eps =*/ eps,
|
||||||
|
/*.nef1 =*/ { ne01 },
|
||||||
|
/*.nef2 =*/ { ne02 },
|
||||||
|
/*.nef3 =*/ { ne03 },
|
||||||
|
/*.nbf1 =*/ { nb01 },
|
||||||
|
/*.nbf2 =*/ { nb02 },
|
||||||
|
/*.nbf3 =*/ { nb03 },
|
||||||
};
|
};
|
||||||
|
|
||||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op);
|
ggml_op fops[8];
|
||||||
|
|
||||||
|
int n_fuse = 1;
|
||||||
|
|
||||||
|
ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 };
|
||||||
|
|
||||||
|
// d[0] = norm(a)
|
||||||
|
// d[1] = mul(d[0], b)
|
||||||
|
// d[2] = add(d[1], c)
|
||||||
|
if (use_fusion) {
|
||||||
|
fops[0] = op->op;
|
||||||
|
fops[1] = GGML_OP_MUL;
|
||||||
|
fops[2] = GGML_OP_ADD;
|
||||||
|
|
||||||
|
for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
|
||||||
|
if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ops[n_fuse + 1]->src[1]->ne[0] != op->ne[0]) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!ggml_is_contiguous_rows(ops[n_fuse + 1]->src[1])) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ops[n_fuse + 1]->type != GGML_TYPE_F32) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
//ctx->fuse_cnt[ops[n_fuse + 1]->op]++;
|
||||||
|
|
||||||
|
bid_fuse[n_fuse] = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]);
|
||||||
|
|
||||||
|
args.nef1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[1];
|
||||||
|
args.nef2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[2];
|
||||||
|
args.nef3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[3];
|
||||||
|
|
||||||
|
args.nbf1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[1];
|
||||||
|
args.nbf2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[2];
|
||||||
|
args.nbf3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[3];
|
||||||
|
}
|
||||||
|
|
||||||
|
++n_fuse;
|
||||||
|
|
||||||
|
if (debug_fusion > 1 && n_fuse > 1) {
|
||||||
|
if (n_fuse == 2) {
|
||||||
|
GGML_LOG_DEBUG("%s: fuse: %s + MUL\n", __func__, ggml_op_name(op->op));
|
||||||
|
}
|
||||||
|
if (n_fuse == 3) {
|
||||||
|
GGML_LOG_DEBUG("%s: fuse: %s + MUL + ADD\n", __func__, ggml_op_name(op->op));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n_fuse > 1) {
|
||||||
|
bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]);
|
||||||
|
|
||||||
|
for (int i = 1; i < n_fuse; ++i) {
|
||||||
|
if (!ggml_metal_op_concurrency_check(ctx, ops[i])) {
|
||||||
|
ggml_metal_op_concurrency_reset(ctx);
|
||||||
|
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
|
||||||
|
|
||||||
int nth = 32; // SIMD width
|
int nth = 32; // SIMD width
|
||||||
while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
||||||
|
while (nth < args.ne00_t && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||||
nth *= 2;
|
nth *= 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||||
nth = std::min(nth, ne00/4);
|
nth = std::min(nth, args.ne00_t);
|
||||||
|
|
||||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||||
|
|
||||||
const int64_t nrows = ggml_nrows(op->src[0]);
|
|
||||||
|
|
||||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 2);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 3);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
|
||||||
|
|
||||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||||
|
|
||||||
ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
|
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
||||||
|
|
||||||
return 1;
|
return n_fuse;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
|
int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,6 @@ int ggml_metal_op_mul_mat_id (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_rms_norm (ggml_metal_op_t ctx, int idx);
|
|
||||||
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
|
||||||
|
|
|
||||||
|
|
@ -66,6 +66,10 @@ static inline float e8m0_to_fp32(uint8_t x) {
|
||||||
return as_type<float>(bits);
|
return as_type<float>(bits);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline float dot(float x, float y) {
|
||||||
|
return x*y;
|
||||||
|
}
|
||||||
|
|
||||||
// NOTE: this is not dequantizing - we are simply fitting the template
|
// NOTE: this is not dequantizing - we are simply fitting the template
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
||||||
|
|
@ -2493,30 +2497,43 @@ kernel void kernel_argmax_f32(
|
||||||
dst_i32[tgpig] = arg_val;
|
dst_i32[tgpig] = arg_val;
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_norm_f32(
|
// F == 1 : norm (no fuse)
|
||||||
|
// F == 2 : norm + mul
|
||||||
|
// F == 3 : norm + mul + add
|
||||||
|
template <typename T, short F>
|
||||||
|
kernel void kernel_norm_fuse_impl(
|
||||||
constant ggml_metal_kargs_norm & args,
|
constant ggml_metal_kargs_norm & args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
|
device const char * src1_0,
|
||||||
|
device const char * src1_1,
|
||||||
device char * dst,
|
device char * dst,
|
||||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||||
uint tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
ushort tpitg[[thread_position_in_threadgroup]],
|
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||||
ushort tiisg[[thread_index_in_simdgroup]],
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
ushort ntg[[threads_per_threadgroup]]) {
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||||
if (sgitg == 0) {
|
if (sgitg == 0) {
|
||||||
shmem_f32[tiisg] = 0.0f;
|
shmem_f32[tiisg] = 0.0f;
|
||||||
}
|
}
|
||||||
|
|
||||||
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
|
const int i01 = tgpig.x;
|
||||||
|
const int i02 = tgpig.y;
|
||||||
|
const int i03 = tgpig.z;
|
||||||
|
|
||||||
float4 sumf4(0.0f);
|
device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
|
||||||
|
|
||||||
|
device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
|
||||||
|
device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
|
||||||
|
|
||||||
|
T sumft(0.0f);
|
||||||
|
|
||||||
float sumf = 0.0f;
|
float sumf = 0.0f;
|
||||||
|
|
||||||
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
|
||||||
sumf4 += x[i00];
|
sumft += x[i00];
|
||||||
}
|
}
|
||||||
sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3];
|
sumf = dot(sumft, T(1.0f));
|
||||||
sumf = simd_sum(sumf);
|
sumf = simd_sum(sumf);
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
@ -2532,10 +2549,10 @@ kernel void kernel_norm_f32(
|
||||||
|
|
||||||
const float mean = sumf/args.ne00;
|
const float mean = sumf/args.ne00;
|
||||||
|
|
||||||
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
|
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
|
||||||
|
|
||||||
sumf = 0.0f;
|
sumf = 0.0f;
|
||||||
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
|
||||||
y[i00] = x[i00] - mean;
|
y[i00] = x[i00] - mean;
|
||||||
sumf += dot(y[i00], y[i00]);
|
sumf += dot(y[i00], y[i00]);
|
||||||
}
|
}
|
||||||
|
|
@ -2555,17 +2572,35 @@ kernel void kernel_norm_f32(
|
||||||
const float variance = sumf/args.ne00;
|
const float variance = sumf/args.ne00;
|
||||||
|
|
||||||
const float scale = 1.0f/sqrt(variance + args.eps);
|
const float scale = 1.0f/sqrt(variance + args.eps);
|
||||||
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
|
||||||
y[i00] = y[i00] * scale;
|
if (F == 1) {
|
||||||
|
y[i00] = (y[i00]*scale);
|
||||||
|
}
|
||||||
|
if (F == 2) {
|
||||||
|
y[i00] = (y[i00]*scale)*f0[i00];
|
||||||
|
}
|
||||||
|
if (F == 3) {
|
||||||
|
y[i00] = (y[i00]*scale)*f0[i00] + f1[i00];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
typedef decltype(kernel_norm_fuse_impl<float4, 1>) kernel_norm_fuse_t;
|
||||||
|
|
||||||
|
template [[host_name("kernel_norm_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 1>;
|
||||||
|
template [[host_name("kernel_norm_mul_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 2>;
|
||||||
|
template [[host_name("kernel_norm_mul_add_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 3>;
|
||||||
|
|
||||||
|
template [[host_name("kernel_norm_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 1>;
|
||||||
|
template [[host_name("kernel_norm_mul_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 2>;
|
||||||
|
template [[host_name("kernel_norm_mul_add_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 3>;
|
||||||
|
|
||||||
// F == 1 : rms_norm (no fuse)
|
// F == 1 : rms_norm (no fuse)
|
||||||
// F == 2 : rms_norm + mul
|
// F == 2 : rms_norm + mul
|
||||||
// F == 3 : rms_norm + mul + add
|
// F == 3 : rms_norm + mul + add
|
||||||
template <short F>
|
template <typename T, short F>
|
||||||
kernel void kernel_rms_norm_fuse_impl(
|
kernel void kernel_rms_norm_fuse_impl(
|
||||||
constant ggml_metal_kargs_rms_norm & args,
|
constant ggml_metal_kargs_norm & args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1_0,
|
device const char * src1_0,
|
||||||
device const char * src1_1,
|
device const char * src1_1,
|
||||||
|
|
@ -2584,15 +2619,15 @@ kernel void kernel_rms_norm_fuse_impl(
|
||||||
const int i02 = tgpig.y;
|
const int i02 = tgpig.y;
|
||||||
const int i03 = tgpig.z;
|
const int i03 = tgpig.z;
|
||||||
|
|
||||||
device const float4 * x = (device const float4 *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
|
device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
|
||||||
|
|
||||||
device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
|
device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
|
||||||
device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
|
device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
|
||||||
|
|
||||||
float sumf = 0.0f;
|
float sumf = 0.0f;
|
||||||
|
|
||||||
// parallel sum
|
// parallel sum
|
||||||
for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
|
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
|
||||||
sumf += dot(x[i00], x[i00]);
|
sumf += dot(x[i00], x[i00]);
|
||||||
}
|
}
|
||||||
sumf = simd_sum(sumf);
|
sumf = simd_sum(sumf);
|
||||||
|
|
@ -2611,8 +2646,8 @@ kernel void kernel_rms_norm_fuse_impl(
|
||||||
const float mean = sumf/args.ne00;
|
const float mean = sumf/args.ne00;
|
||||||
const float scale = 1.0f/sqrt(mean + args.eps);
|
const float scale = 1.0f/sqrt(mean + args.eps);
|
||||||
|
|
||||||
device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
|
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
|
||||||
for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
|
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
|
||||||
if (F == 1) {
|
if (F == 1) {
|
||||||
y[i00] = (x[i00]*scale);
|
y[i00] = (x[i00]*scale);
|
||||||
}
|
}
|
||||||
|
|
@ -2625,11 +2660,15 @@ kernel void kernel_rms_norm_fuse_impl(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
typedef decltype(kernel_rms_norm_fuse_impl<1>) kernel_rms_norm_fuse_t;
|
typedef decltype(kernel_rms_norm_fuse_impl<float4, 1>) kernel_rms_norm_fuse_t;
|
||||||
|
|
||||||
template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>;
|
template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 1>;
|
||||||
template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>;
|
template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 2>;
|
||||||
template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>;
|
template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 3>;
|
||||||
|
|
||||||
|
template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 1>;
|
||||||
|
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
|
||||||
|
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
|
||||||
|
|
||||||
kernel void kernel_l2_norm_f32(
|
kernel void kernel_l2_norm_f32(
|
||||||
constant ggml_metal_kargs_l2_norm & args,
|
constant ggml_metal_kargs_l2_norm & args,
|
||||||
|
|
|
||||||
|
|
@ -6117,7 +6117,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
|
test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
|
||||||
}
|
}
|
||||||
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
|
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
|
||||||
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
|
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
|
||||||
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
|
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
|
||||||
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
|
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
|
||||||
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
|
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue