Merge 646038f776 into 58062860af
commit 53e629feeb
@@ -384,6 +384,27 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_me
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy(ggml_metal_library_t lib, const ggml_tensor * op) {
+    GGML_ASSERT(!op->src[0] || op->src[0]->type == GGML_TYPE_F32);
+
+    char base[256];
+    char name[256];
+
+    const ggml_type tsrc1 = GGML_TYPE_F32;
+
+    snprintf(base, 256, "kernel_cross_entropy_loss_%s", ggml_type_name(tsrc1));
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    }
+
+    res.smem = 32*sizeof(float);
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {
     GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
@@ -116,6 +116,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_bl
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add      (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri             (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max        (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy   (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv        (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan        (ggml_metal_library_t lib, const struct ggml_tensor * op);
@@ -908,6 +908,17 @@ typedef struct {
     int64_t np;
 } ggml_metal_kargs_pool_2d;
 
+typedef struct {
+    int32_t ne00;
+    int32_t nrows;
+    int32_t k;
+} ggml_metal_kargs_cross_entropy_loss;
+
+typedef struct {
+    int32_t ne00;
+    int32_t nrows;
+} ggml_metal_kargs_cross_entropy_loss_back;
+
 typedef struct {
     int64_t  ne00;
     uint64_t nb01;
@@ -1333,6 +1333,46 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_cross_entropy_loss(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+    const ggml_tensor * src0 = op->src[0]; // NOTE: logits
+    const ggml_tensor * src1 = op->src[1]; // NOTE: labels
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+
+    const int32_t ne00  = src0->ne[0];
+    const int32_t nrows = ggml_nrows(src1);
+    ggml_metal_kargs_cross_entropy_loss args = {
+        /*ne00  =*/ ne00,
+        /*nrows =*/ nrows,
+        /*k     =*/ nrows,
+    };
+    const int nth = 32;
+    auto pipeline = ggml_metal_library_get_pipeline_cross_entropy(lib, op);
+
+    const size_t smem = pipeline.smem;
+
+    ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    if (op->src[1]) {
+        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+    } else {
+        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 2);
+    }
+
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
+
+    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+    // one threadgroup of nth threads per row
+    ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
+    return 1;
+}
+
 int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
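For reference, the quantity this op is expected to produce is the mean negative log-likelihood over the nrows rows, sketched here (not part of the diff) with N = nrows, C = ne00, x = logits, y = labels:

    \mathcal{L} = -\frac{1}{N} \sum_{r=1}^{N} \sum_{i=1}^{C} y_{r,i} \left( x_{r,i} - \max_j x_{r,j} - \log \sum_{j=1}^{C} e^{x_{r,j} - \max_k x_{r,k}} \right)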
@@ -57,6 +57,7 @@ int ggml_metal_op_cumsum            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_get_rows          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_set_rows          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_soft_max          (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_cross_entropy_loss(ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_conv          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_scan          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rwkv              (ggml_metal_op_t ctx, int idx);
@@ -2280,6 +2280,146 @@ template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kerne
 template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
 template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
 
+template <typename T>
+kernel void kernel_cross_entropy_loss(
+        constant ggml_metal_kargs_cross_entropy_loss & args,
+        device const float * logits,
+        device const float * labels,
+        device       float * dst,
+        threadgroup  float * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint tptg[[threads_per_threadgroup]]) {
+
+    const int row = tgpig;
+    device const float * logits_row = logits + row * args.ne00;
+    device const float * labels_row = labels + row * args.ne00;
+
+    float lmax = -INFINITY;
+    for (int i = tpitg; i < args.ne00; i += tptg) {
+        lmax = MAX(lmax, logits_row[i]);
+    }
+    float max_val = simd_max(lmax);
+    if (tptg > N_SIMDWIDTH) {
+        if (sgitg == 0) buf[tiisg] = -INFINITY;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) buf[sgitg] = max_val;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }
+
+    float lsum = 0.0f;
+    for (int i = tpitg; i < args.ne00; i += tptg) {
+        const float exp_val = exp(logits_row[i] - max_val);
+        lsum += exp_val;
+        dst[i] = exp_val;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    float sum = simd_sum(lsum);
+    if (tptg > N_SIMDWIDTH) {
+        if (sgitg == 0) buf[tiisg] = 0.0f;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) buf[sgitg] = sum;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }
+    const float log_sum = log(sum);
+    float lloss = 0.0f;
+    for (int i = tpitg; i < args.ne00; i += tptg) {
+        const float log_softmax_i = logits_row[i] - max_val - log_sum;
+        lloss += log_softmax_i * labels_row[i];
+    }
+
+    float loss = simd_sum(lloss);
+    if (tptg > N_SIMDWIDTH) {
+        if (sgitg == 0) buf[tiisg] = 0.0f;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) buf[sgitg] = loss;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loss = buf[tiisg];
+        loss = simd_sum(loss);
+    }
+    if (tpitg == 0) {
+        dst[row] = -loss / args.nrows;
+    }
+}
+
+
+template <typename T>
+kernel void kernel_cross_entropy_loss_back(
+        constant ggml_metal_kargs_cross_entropy_loss_back & args,
+        device const float * grad,
+        device const float * logits, // src0
+        device const float * labels, // src1
+        device       float * dst,    // output
+        threadgroup  float * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint tptg[[threads_per_threadgroup]]) {
+
+    const int row = tgpig;
+
+    device const float * logits_row = logits + row * args.ne00;
+    device const float * labels_row = labels + row * args.ne00;
+    device       float * dst_row    = dst    + row * args.ne00;
+
+    // find max
+    float lmax = -INFINITY;
+    for (int i = tpitg; i < args.ne00; i += tptg) {
+        lmax = MAX(lmax, logits_row[i]);
+    }
+    float max_val = simd_max(lmax);
+    if (tptg > N_SIMDWIDTH) {
+        if (sgitg == 0) buf[tiisg] = -INFINITY;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) buf[sgitg] = max_val;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }
+
+    float lsum = 0.0f;
+    for (int i = tpitg; i < args.ne00; i += tptg) {
+        const float exp_val = exp(logits_row[i] - max_val);
+        lsum += exp_val;
+        dst_row[i] = exp_val;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    float sum = simd_sum(lsum);
+    if (tptg > N_SIMDWIDTH) {
+        if (sgitg == 0) buf[tiisg] = 0.0f;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) buf[sgitg] = sum;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }
+    const float inv_sum    = 1.0f / sum;
+    const float d_by_nrows = grad[0] / args.nrows;
+
+    for (int i = tpitg; i < args.ne00; i += tptg) {
+        const float softmax_i = dst_row[i] * inv_sum; // exp(logits - max) / sum(exp(logits - max))
+        dst_row[i] = (softmax_i - labels_row[i]) * d_by_nrows;
+    }
+}
+
+typedef decltype(kernel_cross_entropy_loss<float>)      kernel_cross_entropy_loss_t;
+typedef decltype(kernel_cross_entropy_loss_back<float>) kernel_cross_entropy_loss_back_t;
+
+template [[host_name("kernel_cross_entropy_loss_f32")]]
+kernel kernel_cross_entropy_loss_t kernel_cross_entropy_loss<float>;
+
+template [[host_name("kernel_cross_entropy_loss_back_f32")]]
+kernel kernel_cross_entropy_loss_back_t kernel_cross_entropy_loss_back<float>;
+
 // ref: ggml.c:ggml_compute_forward_ssm_conv_f32
 kernel void kernel_ssm_conv_f32_f32(
         constant ggml_metal_kargs_ssm_conv & args,
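The per-element write in kernel_cross_entropy_loss_back corresponds to the standard softmax cross entropy gradient, scaled by the incoming scalar gradient g = grad[0] and the row count N = nrows (a sketch under the same conventions as above, not part of the diff):

    \frac{\partial \mathcal{L}}{\partial x_{r,i}} = \frac{g}{N} \left( \operatorname{softmax}(x_r)_i - y_{r,i} \right)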