metal : extend l2_norm support for non-cont src0 (#19502)
This commit is contained in:
parent
ada90bf2ba
commit
9ab072ebbe
|
|
@ -1480,13 +1480,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_met
|
||||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||||
assert(op->op == GGML_OP_L2_NORM);
|
assert(op->op == GGML_OP_L2_NORM);
|
||||||
|
|
||||||
GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
|
|
||||||
GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
|
|
||||||
|
|
||||||
char base[256];
|
char base[256];
|
||||||
char name[256];
|
char name[256];
|
||||||
|
|
||||||
snprintf(base, 256, "kernel_l2_norm_f32");
|
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
|
||||||
|
|
||||||
|
const char * t0_str = ggml_type_name(op->src[0]->type);
|
||||||
|
const char * t_str = ggml_type_name(op->type);
|
||||||
|
|
||||||
|
snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
|
||||||
snprintf(name, 256, "%s", base);
|
snprintf(name, 256, "%s", base);
|
||||||
|
|
||||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||||
|
|
@ -1494,6 +1496,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
|
||||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
res.c4 = is_c4;
|
||||||
res.smem = 32*sizeof(float);
|
res.smem = 32*sizeof(float);
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
|
|
|
||||||
|
|
@ -1086,9 +1086,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||||
case GGML_OP_MEAN:
|
case GGML_OP_MEAN:
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
case GGML_OP_GROUP_NORM:
|
case GGML_OP_GROUP_NORM:
|
||||||
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
|
||||||
case GGML_OP_L2_NORM:
|
case GGML_OP_L2_NORM:
|
||||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
||||||
case GGML_OP_COUNT_EQUAL:
|
case GGML_OP_COUNT_EQUAL:
|
||||||
return has_simdgroup_reduction &&
|
return has_simdgroup_reduction &&
|
||||||
op->src[0]->type == GGML_TYPE_I32 &&
|
op->src[0]->type == GGML_TYPE_I32 &&
|
||||||
|
|
|
||||||
|
|
@ -539,8 +539,21 @@ typedef struct {
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int32_t ne00;
|
int32_t ne00;
|
||||||
int32_t ne00_4;
|
int32_t ne01;
|
||||||
|
int32_t ne02;
|
||||||
|
int32_t ne03;
|
||||||
|
uint64_t nb00;
|
||||||
uint64_t nb01;
|
uint64_t nb01;
|
||||||
|
uint64_t nb02;
|
||||||
|
uint64_t nb03;
|
||||||
|
int32_t ne0;
|
||||||
|
int32_t ne1;
|
||||||
|
int32_t ne2;
|
||||||
|
int32_t ne3;
|
||||||
|
uint64_t nb0;
|
||||||
|
uint64_t nb1;
|
||||||
|
uint64_t nb2;
|
||||||
|
uint64_t nb3;
|
||||||
float eps;
|
float eps;
|
||||||
} ggml_metal_kargs_l2_norm;
|
} ggml_metal_kargs_l2_norm;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2979,39 +2979,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||||
|
|
||||||
|
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||||
|
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||||
|
|
||||||
float eps;
|
float eps;
|
||||||
memcpy(&eps, op->op_params, sizeof(float));
|
memcpy(&eps, op->op_params, sizeof(float));
|
||||||
|
|
||||||
int nth = 32; // SIMD width
|
|
||||||
|
|
||||||
ggml_metal_kargs_l2_norm args = {
|
ggml_metal_kargs_l2_norm args = {
|
||||||
/*.ne00 =*/ ne00,
|
/*.ne00 =*/ ne00,
|
||||||
/*.ne00_4 =*/ ne00/4,
|
/*.ne01 =*/ ne01,
|
||||||
/*.nb01 =*/ nb01,
|
/*.ne02 =*/ ne02,
|
||||||
/*.eps =*/ eps,
|
/*.ne03 =*/ ne03,
|
||||||
|
/*.nb00 =*/ nb00,
|
||||||
|
/*.nb01 =*/ nb01,
|
||||||
|
/*.nb02 =*/ nb02,
|
||||||
|
/*.nb03 =*/ nb03,
|
||||||
|
/*.ne0 =*/ ne0,
|
||||||
|
/*.ne1 =*/ ne1,
|
||||||
|
/*.ne2 =*/ ne2,
|
||||||
|
/*.ne3 =*/ ne3,
|
||||||
|
/*.nb0 =*/ nb0,
|
||||||
|
/*.nb1 =*/ nb1,
|
||||||
|
/*.nb2 =*/ nb2,
|
||||||
|
/*.nb3 =*/ nb3,
|
||||||
|
/*.eps =*/ eps,
|
||||||
};
|
};
|
||||||
|
|
||||||
auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
|
auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
|
||||||
|
|
||||||
while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
if (pipeline.c4) {
|
||||||
|
args.ne00 = ne00/4;
|
||||||
|
args.ne0 = ne0/4;
|
||||||
|
}
|
||||||
|
|
||||||
|
int nth = 32; // SIMD width
|
||||||
|
|
||||||
|
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||||
nth *= 2;
|
nth *= 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||||
nth = std::min(nth, ne00/4);
|
|
||||||
|
|
||||||
const size_t smem = pipeline.smem;
|
const size_t smem = pipeline.smem;
|
||||||
|
|
||||||
const int64_t nrows = ggml_nrows(op->src[0]);
|
|
||||||
|
|
||||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||||
|
|
||||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||||
|
|
||||||
ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
|
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
||||||
|
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2706,26 +2706,32 @@ template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_f
|
||||||
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
|
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
|
||||||
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
|
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
|
||||||
|
|
||||||
kernel void kernel_l2_norm_f32(
|
template <typename T0, typename T>
|
||||||
|
kernel void kernel_l2_norm_impl(
|
||||||
constant ggml_metal_kargs_l2_norm & args,
|
constant ggml_metal_kargs_l2_norm & args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device char * dst,
|
device char * dst,
|
||||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||||
uint tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
ushort tpitg[[thread_position_in_threadgroup]],
|
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||||
ushort tiisg[[thread_index_in_simdgroup]],
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
ushort ntg[[threads_per_threadgroup]]) {
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
const int i03 = tgpig.z;
|
||||||
|
const int i02 = tgpig.y;
|
||||||
|
const int i01 = tgpig.x;
|
||||||
|
|
||||||
if (sgitg == 0) {
|
if (sgitg == 0) {
|
||||||
shmem_f32[tiisg] = 0.0f;
|
shmem_f32[tiisg] = 0.0f;
|
||||||
}
|
}
|
||||||
|
|
||||||
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
|
device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
||||||
|
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
|
||||||
|
|
||||||
float sumf = 0.0f;
|
float sumf = 0.0f;
|
||||||
|
|
||||||
// parallel sum
|
// parallel sum
|
||||||
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
|
||||||
sumf += dot(x[i00], x[i00]);
|
sumf += dot(x[i00], x[i00]);
|
||||||
}
|
}
|
||||||
sumf = simd_sum(sumf);
|
sumf = simd_sum(sumf);
|
||||||
|
|
@ -2743,12 +2749,16 @@ kernel void kernel_l2_norm_f32(
|
||||||
|
|
||||||
const float scale = 1.0f/sqrt(max(sumf, args.eps));
|
const float scale = 1.0f/sqrt(max(sumf, args.eps));
|
||||||
|
|
||||||
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
|
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
|
||||||
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
|
||||||
y[i00] = x[i00] * scale;
|
y[i00] = x[i00] * scale;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
|
||||||
|
|
||||||
|
template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
|
||||||
|
template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
|
||||||
|
|
||||||
kernel void kernel_group_norm_f32(
|
kernel void kernel_group_norm_f32(
|
||||||
constant ggml_metal_kargs_group_norm & args,
|
constant ggml_metal_kargs_group_norm & args,
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue