metal : add cumsum (#17305)
This commit is contained in:
parent
2376b7758c
commit
1a139644a8
|
|
@ -318,6 +318,44 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_librar
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||||
|
GGML_ASSERT(op->op == GGML_OP_CUMSUM);
|
||||||
|
|
||||||
|
char base[256];
|
||||||
|
char name[256];
|
||||||
|
|
||||||
|
snprintf(base, 256, "kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type));
|
||||||
|
snprintf(name, 256, "%s", base);
|
||||||
|
|
||||||
|
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||||
|
if (res) {
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||||
|
GGML_ASSERT(op->op == GGML_OP_CUMSUM);
|
||||||
|
|
||||||
|
char base[256];
|
||||||
|
char name[256];
|
||||||
|
|
||||||
|
snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type));
|
||||||
|
snprintf(name, 256, "%s", base);
|
||||||
|
|
||||||
|
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||||
|
if (res) {
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||||
GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);
|
GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -113,6 +113,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_me
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
|
|
|
||||||
|
|
@ -870,6 +870,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||||
case GGML_OP_SUM:
|
case GGML_OP_SUM:
|
||||||
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
|
case GGML_OP_CUMSUM:
|
||||||
case GGML_OP_MEAN:
|
case GGML_OP_MEAN:
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
case GGML_OP_GROUP_NORM:
|
case GGML_OP_GROUP_NORM:
|
||||||
|
|
|
||||||
|
|
@ -612,6 +612,45 @@ typedef struct {
|
||||||
uint64_t nb3;
|
uint64_t nb3;
|
||||||
} ggml_metal_kargs_sum_rows;
|
} ggml_metal_kargs_sum_rows;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
int64_t ne00;
|
||||||
|
int64_t ne01;
|
||||||
|
int64_t ne02;
|
||||||
|
int64_t ne03;
|
||||||
|
uint64_t nb00;
|
||||||
|
uint64_t nb01;
|
||||||
|
uint64_t nb02;
|
||||||
|
uint64_t nb03;
|
||||||
|
int64_t net0;
|
||||||
|
int64_t net1;
|
||||||
|
int64_t net2;
|
||||||
|
int64_t net3;
|
||||||
|
uint64_t nbt0;
|
||||||
|
uint64_t nbt1;
|
||||||
|
uint64_t nbt2;
|
||||||
|
uint64_t nbt3;
|
||||||
|
bool outb;
|
||||||
|
} ggml_metal_kargs_cumsum_blk;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
int64_t ne00;
|
||||||
|
int64_t ne01;
|
||||||
|
int64_t ne02;
|
||||||
|
int64_t ne03;
|
||||||
|
uint64_t nb00;
|
||||||
|
uint64_t nb01;
|
||||||
|
uint64_t nb02;
|
||||||
|
uint64_t nb03;
|
||||||
|
int64_t net0;
|
||||||
|
int64_t net1;
|
||||||
|
int64_t net2;
|
||||||
|
int64_t net3;
|
||||||
|
uint64_t nbt0;
|
||||||
|
uint64_t nbt1;
|
||||||
|
uint64_t nbt2;
|
||||||
|
uint64_t nbt3;
|
||||||
|
} ggml_metal_kargs_cumsum_add;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int32_t ne00;
|
int32_t ne00;
|
||||||
int32_t ne01;
|
int32_t ne01;
|
||||||
|
|
|
||||||
|
|
@ -311,6 +311,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||||
{
|
{
|
||||||
n_fuse = ggml_metal_op_sum_rows(ctx, idx);
|
n_fuse = ggml_metal_op_sum_rows(ctx, idx);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_CUMSUM:
|
||||||
|
{
|
||||||
|
n_fuse = ggml_metal_op_cumsum(ctx, idx);
|
||||||
|
} break;
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
{
|
{
|
||||||
n_fuse = ggml_metal_op_soft_max(ctx, idx);
|
n_fuse = ggml_metal_op_soft_max(ctx, idx);
|
||||||
|
|
@ -539,7 +543,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
|
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
|
||||||
|
|
||||||
|
|
@ -585,7 +589,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
|
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
||||||
|
|
@ -694,7 +698,7 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
float scale;
|
float scale;
|
||||||
float bias;
|
float bias;
|
||||||
|
|
@ -733,7 +737,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
float min;
|
float min;
|
||||||
float max;
|
float max;
|
||||||
|
|
@ -772,7 +776,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
int64_t n = ggml_nelements(op);
|
int64_t n = ggml_nelements(op);
|
||||||
|
|
||||||
|
|
@ -802,7 +806,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
if (op->src[1]) {
|
if (op->src[1]) {
|
||||||
GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
|
GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
|
||||||
|
|
@ -834,18 +838,6 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
||||||
|
|
||||||
const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
|
const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
|
||||||
|
|
||||||
//[encoder setComputePipelineState:pipeline];
|
|
||||||
//[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
||||||
//if (src1) {
|
|
||||||
// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
||||||
//} else {
|
|
||||||
// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
||||||
//}
|
|
||||||
//[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
||||||
//[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
|
||||||
|
|
||||||
//[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
||||||
|
|
||||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||||
|
|
@ -907,7 +899,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
ggml_metal_kargs_sum_rows args = {
|
ggml_metal_kargs_sum_rows args = {
|
||||||
/*.ne00 =*/ ne00,
|
/*.ne00 =*/ ne00,
|
||||||
|
|
@ -941,14 +933,6 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||||
|
|
||||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||||
|
|
||||||
//[encoder setComputePipelineState:pipeline];
|
|
||||||
//[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
||||||
//[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
||||||
//[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
||||||
//[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
||||||
|
|
||||||
//[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
||||||
|
|
||||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||||
|
|
@ -961,6 +945,149 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
|
||||||
|
ggml_tensor * op = ctx->node(idx);
|
||||||
|
|
||||||
|
ggml_metal_library_t lib = ctx->lib;
|
||||||
|
ggml_metal_encoder_t enc = ctx->enc;
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||||
|
|
||||||
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
|
ggml_metal_pipeline_t pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
|
||||||
|
|
||||||
|
int nth = 1;
|
||||||
|
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
|
||||||
|
nth *= 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_ASSERT(ne00 <= nth*nth);
|
||||||
|
|
||||||
|
const int64_t net0 = (ne00 + nth - 1) / nth;
|
||||||
|
const int64_t net1 = ne01;
|
||||||
|
const int64_t net2 = ne02;
|
||||||
|
const int64_t net3 = ne03;
|
||||||
|
|
||||||
|
const uint64_t nbt0 = sizeof(float);
|
||||||
|
const uint64_t nbt1 = net0*nbt0;
|
||||||
|
const uint64_t nbt2 = net1*nbt1;
|
||||||
|
const uint64_t nbt3 = net2*nbt2;
|
||||||
|
|
||||||
|
const size_t smem = GGML_PAD(32*sizeof(float), 16);
|
||||||
|
|
||||||
|
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||||
|
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||||
|
|
||||||
|
ggml_metal_buffer_id bid_tmp = bid_dst;
|
||||||
|
bid_tmp.offs += ggml_nbytes(op);
|
||||||
|
|
||||||
|
{
|
||||||
|
ggml_metal_kargs_cumsum_blk args = {
|
||||||
|
/*.ne00 =*/ ne00,
|
||||||
|
/*.ne01 =*/ ne01,
|
||||||
|
/*.ne02 =*/ ne02,
|
||||||
|
/*.ne03 =*/ ne03,
|
||||||
|
/*.nb00 =*/ nb00,
|
||||||
|
/*.nb01 =*/ nb01,
|
||||||
|
/*.nb02 =*/ nb02,
|
||||||
|
/*.nb03 =*/ nb03,
|
||||||
|
/*.net0 =*/ net0,
|
||||||
|
/*.net1 =*/ net1,
|
||||||
|
/*.net2 =*/ net2,
|
||||||
|
/*.net3 =*/ net3,
|
||||||
|
/*.nbt0 =*/ nbt0,
|
||||||
|
/*.nbt1 =*/ nbt1,
|
||||||
|
/*.nbt2 =*/ nbt2,
|
||||||
|
/*.nbt3 =*/ nbt3,
|
||||||
|
/*.outb =*/ ne00 > nth,
|
||||||
|
};
|
||||||
|
|
||||||
|
ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
|
||||||
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
|
||||||
|
|
||||||
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||||
|
|
||||||
|
ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ne00 > nth) {
|
||||||
|
ggml_metal_op_concurrency_reset(ctx);
|
||||||
|
|
||||||
|
{
|
||||||
|
ggml_metal_kargs_cumsum_blk args = {
|
||||||
|
/*.ne00 =*/ net0,
|
||||||
|
/*.ne01 =*/ net1,
|
||||||
|
/*.ne02 =*/ net2,
|
||||||
|
/*.ne03 =*/ net3,
|
||||||
|
/*.nb00 =*/ nbt0,
|
||||||
|
/*.nb01 =*/ nbt1,
|
||||||
|
/*.nb02 =*/ nbt2,
|
||||||
|
/*.nb03 =*/ nbt3,
|
||||||
|
/*.net0 =*/ net0,
|
||||||
|
/*.net1 =*/ net1,
|
||||||
|
/*.net2 =*/ net2,
|
||||||
|
/*.net3 =*/ net3,
|
||||||
|
/*.nbt0 =*/ nbt0,
|
||||||
|
/*.nbt1 =*/ nbt1,
|
||||||
|
/*.nbt2 =*/ nbt2,
|
||||||
|
/*.nbt3 =*/ nbt3,
|
||||||
|
/*.outb =*/ false,
|
||||||
|
};
|
||||||
|
|
||||||
|
ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
|
||||||
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
|
||||||
|
|
||||||
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||||
|
|
||||||
|
ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_metal_op_concurrency_reset(ctx);
|
||||||
|
|
||||||
|
{
|
||||||
|
ggml_metal_pipeline_t pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
|
||||||
|
|
||||||
|
ggml_metal_kargs_cumsum_add args = {
|
||||||
|
/*.ne00 =*/ ne00,
|
||||||
|
/*.ne01 =*/ ne01,
|
||||||
|
/*.ne02 =*/ ne02,
|
||||||
|
/*.ne03 =*/ ne03,
|
||||||
|
/*.nb00 =*/ nb00,
|
||||||
|
/*.nb01 =*/ nb01,
|
||||||
|
/*.nb02 =*/ nb02,
|
||||||
|
/*.nb03 =*/ nb03,
|
||||||
|
/*.net0 =*/ net0,
|
||||||
|
/*.net1 =*/ net1,
|
||||||
|
/*.net2 =*/ net2,
|
||||||
|
/*.net3 =*/ net3,
|
||||||
|
/*.nbt0 =*/ nbt0,
|
||||||
|
/*.nbt1 =*/ nbt1,
|
||||||
|
/*.nbt2 =*/ nbt2,
|
||||||
|
/*.nbt3 =*/ nbt3,
|
||||||
|
};
|
||||||
|
|
||||||
|
ggml_metal_encoder_set_pipeline(enc, pipeline_add);
|
||||||
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||||
|
|
||||||
|
ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
|
int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
|
||||||
ggml_tensor * op = ctx->node(idx);
|
ggml_tensor * op = ctx->node(idx);
|
||||||
|
|
||||||
|
|
@ -972,7 +1099,7 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
|
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
|
||||||
|
|
||||||
|
|
@ -1017,7 +1144,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
|
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
|
||||||
|
|
||||||
|
|
@ -1081,7 +1208,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
float scale;
|
float scale;
|
||||||
float max_bias;
|
float max_bias;
|
||||||
|
|
@ -1169,7 +1296,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
ggml_metal_kargs_ssm_conv args = {
|
ggml_metal_kargs_ssm_conv args = {
|
||||||
/*.ne00 =*/ ne00,
|
/*.ne00 =*/ ne00,
|
||||||
|
|
@ -1224,7 +1351,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
const ggml_tensor * src3 = op->src[3];
|
const ggml_tensor * src3 = op->src[3];
|
||||||
const ggml_tensor * src4 = op->src[4];
|
const ggml_tensor * src4 = op->src[4];
|
||||||
|
|
@ -1310,7 +1437,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
|
const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
|
||||||
const int64_t T = op->src[0]->ne[2];
|
const int64_t T = op->src[0]->ne[2];
|
||||||
|
|
@ -1351,7 +1478,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
||||||
|
|
||||||
|
|
@ -1424,7 +1551,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
const int32_t * opts = op->op_params;
|
const int32_t * opts = op->op_params;
|
||||||
ggml_op_pool op_pool = (ggml_op_pool) opts[0];
|
ggml_op_pool op_pool = (ggml_op_pool) opts[0];
|
||||||
|
|
@ -1488,7 +1615,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
GGML_ASSERT(ne00 == ne10);
|
GGML_ASSERT(ne00 == ne10);
|
||||||
|
|
||||||
|
|
@ -1729,7 +1856,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
// src2 = ids
|
// src2 = ids
|
||||||
GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);
|
GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);
|
||||||
|
|
@ -2689,7 +2816,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
float eps;
|
float eps;
|
||||||
memcpy(&eps, op->op_params, sizeof(float));
|
memcpy(&eps, op->op_params, sizeof(float));
|
||||||
|
|
@ -2737,7 +2864,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
const int32_t ngrp = ((const int32_t *) op->op_params)[0];
|
const int32_t ngrp = ((const int32_t *) op->op_params)[0];
|
||||||
|
|
||||||
|
|
@ -2792,7 +2919,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
float eps;
|
float eps;
|
||||||
memcpy(&eps, op->op_params, sizeof(float));
|
memcpy(&eps, op->op_params, sizeof(float));
|
||||||
|
|
@ -2928,7 +3055,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
// make sure we have one or more position id(ne10) per token(ne02)
|
// make sure we have one or more position id(ne10) per token(ne02)
|
||||||
GGML_ASSERT(ne10 % ne02 == 0);
|
GGML_ASSERT(ne10 % ne02 == 0);
|
||||||
|
|
@ -3022,7 +3149,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
|
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
|
||||||
const int32_t s1 = ((const int32_t *)(op->op_params))[1];
|
const int32_t s1 = ((const int32_t *)(op->op_params))[1];
|
||||||
|
|
@ -3172,7 +3299,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
|
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
|
||||||
|
|
||||||
|
|
@ -3217,7 +3344,7 @@ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
|
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
|
||||||
|
|
||||||
|
|
@ -3271,7 +3398,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
const float sf0 = (float)ne0/op->src[0]->ne[0];
|
const float sf0 = (float)ne0/op->src[0]->ne[0];
|
||||||
const float sf1 = (float)ne1/op->src[0]->ne[1];
|
const float sf1 = (float)ne1/op->src[0]->ne[1];
|
||||||
|
|
@ -3324,7 +3451,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
ggml_metal_kargs_pad args = {
|
ggml_metal_kargs_pad args = {
|
||||||
/*.ne00 =*/ ne00,
|
/*.ne00 =*/ ne00,
|
||||||
|
|
@ -3368,7 +3495,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
ggml_metal_kargs_pad_reflect_1d args = {
|
ggml_metal_kargs_pad_reflect_1d args = {
|
||||||
/*.ne00 =*/ ne00,
|
/*.ne00 =*/ ne00,
|
||||||
|
|
@ -3412,7 +3539,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
|
||||||
ggml_metal_encoder_t enc = ctx->enc;
|
ggml_metal_encoder_t enc = ctx->enc;
|
||||||
|
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
float start;
|
float start;
|
||||||
float step;
|
float step;
|
||||||
|
|
@ -3430,12 +3557,6 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
|
||||||
|
|
||||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
|
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
|
||||||
|
|
||||||
//[encoder setComputePipelineState:pipeline];
|
|
||||||
//[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
|
|
||||||
//[encoder setBytes:&args length:sizeof(args) atIndex:1];
|
|
||||||
|
|
||||||
//[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
||||||
|
|
||||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
|
||||||
|
|
@ -3454,7 +3575,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
const int dim = op->op_params[0];
|
const int dim = op->op_params[0];
|
||||||
const int max_period = op->op_params[1];
|
const int max_period = op->op_params[1];
|
||||||
|
|
@ -3488,7 +3609,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
ggml_metal_kargs_argmax args = {
|
ggml_metal_kargs_argmax args = {
|
||||||
/*.ne00 = */ ne00,
|
/*.ne00 = */ ne00,
|
||||||
|
|
@ -3529,7 +3650,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
|
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
|
||||||
|
|
||||||
|
|
@ -3539,7 +3660,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||||
nth *= 2;
|
nth *= 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int nptg = (ne00 + nth - 1)/nth;
|
const int npr = (ne00 + nth - 1)/nth;
|
||||||
|
|
||||||
// Metal kernels require the buffer size to be multiple of 16 bytes
|
// Metal kernels require the buffer size to be multiple of 16 bytes
|
||||||
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
|
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
|
||||||
|
|
@ -3551,7 +3672,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||||
ggml_metal_buffer_id bid_tmp = bid_dst;
|
ggml_metal_buffer_id bid_tmp = bid_dst;
|
||||||
bid_tmp.offs += ggml_nbytes(op);
|
bid_tmp.offs += ggml_nbytes(op);
|
||||||
|
|
||||||
if ((int) ceil(std::log(nptg) / std::log(2)) % 2 == 1) {
|
if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
|
||||||
std::swap(bid_dst, bid_tmp);
|
std::swap(bid_dst, bid_tmp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -3573,7 +3694,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||||
|
|
||||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||||
|
|
||||||
ggml_metal_encoder_dispatch_threadgroups(enc, nptg*ne01, ne02, ne03, nth, 1, 1);
|
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
|
||||||
|
|
||||||
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
|
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
|
||||||
|
|
||||||
|
|
@ -3626,7 +3747,7 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
float slope;
|
float slope;
|
||||||
memcpy(&slope, op->op_params, sizeof(float));
|
memcpy(&slope, op->op_params, sizeof(float));
|
||||||
|
|
@ -3662,7 +3783,7 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
|
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
|
||||||
|
|
||||||
|
|
@ -3698,7 +3819,7 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
|
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -52,6 +52,7 @@ int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
|
||||||
|
int ggml_metal_op_cumsum (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx);
|
||||||
|
|
|
||||||
|
|
@ -197,6 +197,7 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
|
||||||
res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
|
res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
|
||||||
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
|
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_CUMSUM:
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
{
|
{
|
||||||
res *= 2;
|
res *= 2;
|
||||||
|
|
|
||||||
|
|
@ -1832,6 +1832,117 @@ typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
|
||||||
template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
|
template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
|
||||||
template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
|
template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
kernel void kernel_cumsum_blk(
|
||||||
|
constant ggml_metal_kargs_cumsum_blk & args,
|
||||||
|
device const char * src0,
|
||||||
|
device char * tmp,
|
||||||
|
device char * dst,
|
||||||
|
threadgroup char * shmem [[threadgroup(0)]],
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||||
|
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||||
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
const int ib = tgpig[0]/args.ne01;
|
||||||
|
|
||||||
|
const int i00 = ib*ntg.x;
|
||||||
|
const int i01 = tgpig[0]%args.ne01;
|
||||||
|
const int i02 = tgpig[1];
|
||||||
|
const int i03 = tgpig[2];
|
||||||
|
|
||||||
|
device const float * src0_row = (device const float *) (src0 +
|
||||||
|
args.nb01*i01 +
|
||||||
|
args.nb02*i02 +
|
||||||
|
args.nb03*i03);
|
||||||
|
|
||||||
|
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
||||||
|
|
||||||
|
float v = 0.0f;
|
||||||
|
|
||||||
|
if (i00 + tpitg.x < args.ne00) {
|
||||||
|
v = src0_row[i00 + tpitg.x];
|
||||||
|
}
|
||||||
|
|
||||||
|
float s = simd_prefix_inclusive_sum(v);
|
||||||
|
|
||||||
|
if (tiisg == N_SIMDWIDTH - 1) {
|
||||||
|
shmem_f32[sgitg] = s;
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
if (sgitg == 0) {
|
||||||
|
shmem_f32[tiisg] = simd_prefix_exclusive_sum(shmem_f32[tiisg]);
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
s += shmem_f32[sgitg];
|
||||||
|
|
||||||
|
device float * dst_row = (device float *) dst +
|
||||||
|
args.ne00*i01 +
|
||||||
|
args.ne00*args.ne01*i02 +
|
||||||
|
args.ne00*args.ne01*args.ne02*i03;
|
||||||
|
|
||||||
|
if (i00 + tpitg.x < args.ne00) {
|
||||||
|
dst_row[i00 + tpitg.x] = s;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (args.outb && tpitg.x == ntg.x - 1) {
|
||||||
|
device float * tmp_row = (device float *) tmp +
|
||||||
|
args.net0*i01 +
|
||||||
|
args.net0*args.net1*i02 +
|
||||||
|
args.net0*args.net1*args.net2*i03;
|
||||||
|
|
||||||
|
tmp_row[ib] = s;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef decltype(kernel_cumsum_blk<float>) kernel_cumsum_blk_t;
|
||||||
|
|
||||||
|
template [[host_name("kernel_cumsum_blk_f32")]] kernel kernel_cumsum_blk_t kernel_cumsum_blk<float>;
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
kernel void kernel_cumsum_add(
|
||||||
|
constant ggml_metal_kargs_cumsum_add & args,
|
||||||
|
device const char * tmp,
|
||||||
|
device char * dst,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||||
|
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||||
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
const int ib = tgpig[0]/args.ne01;
|
||||||
|
|
||||||
|
if (ib == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int i00 = ib*ntg.x;
|
||||||
|
const int i01 = tgpig[0]%args.ne01;
|
||||||
|
const int i02 = tgpig[1];
|
||||||
|
const int i03 = tgpig[2];
|
||||||
|
|
||||||
|
device const float * tmp_row = (device const float *) (tmp +
|
||||||
|
args.nbt1*i01 +
|
||||||
|
args.nbt2*i02 +
|
||||||
|
args.nbt3*i03);
|
||||||
|
|
||||||
|
device float * dst_row = (device float *) dst +
|
||||||
|
args.ne00*i01 +
|
||||||
|
args.ne00*args.ne01*i02 +
|
||||||
|
args.ne00*args.ne01*args.ne02*i03;
|
||||||
|
|
||||||
|
if (i00 + tpitg.x < args.ne00) {
|
||||||
|
dst_row[i00 + tpitg.x] += tmp_row[ib - 1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
|
||||||
|
|
||||||
|
template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
kernel void kernel_soft_max(
|
kernel void kernel_soft_max(
|
||||||
constant ggml_metal_kargs_soft_max & args,
|
constant ggml_metal_kargs_soft_max & args,
|
||||||
|
|
@ -4543,7 +4654,7 @@ typedef void (argsort_t)(
|
||||||
constant ggml_metal_kargs_argsort & args,
|
constant ggml_metal_kargs_argsort & args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device int32_t * dst,
|
device int32_t * dst,
|
||||||
threadgroup int32_t * smem_i32 [[threadgroup(0)]],
|
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||||
ushort3 ntg[[threads_per_threadgroup]]);
|
ushort3 ntg[[threads_per_threadgroup]]);
|
||||||
|
|
@ -4553,7 +4664,7 @@ kernel void kernel_argsort_f32_i32(
|
||||||
constant ggml_metal_kargs_argsort & args,
|
constant ggml_metal_kargs_argsort & args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device int32_t * dst,
|
device int32_t * dst,
|
||||||
threadgroup int32_t * smem_i32 [[threadgroup(0)]],
|
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
|
@ -4565,10 +4676,10 @@ kernel void kernel_argsort_f32_i32(
|
||||||
const int i02 = tgpig[1];
|
const int i02 = tgpig[1];
|
||||||
const int i03 = tgpig[2];
|
const int i03 = tgpig[2];
|
||||||
|
|
||||||
device const float * x_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
|
device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
|
||||||
|
|
||||||
// initialize indices
|
// initialize indices
|
||||||
smem_i32[col] = i00 + col;
|
shmem_i32[col] = i00 + col;
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
|
@ -4577,20 +4688,20 @@ kernel void kernel_argsort_f32_i32(
|
||||||
int ixj = col ^ j;
|
int ixj = col ^ j;
|
||||||
if (ixj > col) {
|
if (ixj > col) {
|
||||||
if ((col & k) == 0) {
|
if ((col & k) == 0) {
|
||||||
if (smem_i32[col] >= args.ne00 ||
|
if (shmem_i32[col] >= args.ne00 ||
|
||||||
(smem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
|
(shmem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
|
||||||
x_row[smem_i32[col]] > x_row[smem_i32[ixj]] :
|
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
|
||||||
x_row[smem_i32[col]] < x_row[smem_i32[ixj]]))
|
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
|
||||||
) {
|
) {
|
||||||
SWAP(smem_i32[col], smem_i32[ixj]);
|
SWAP(shmem_i32[col], shmem_i32[ixj]);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (smem_i32[ixj] >= args.ne00 ||
|
if (shmem_i32[ixj] >= args.ne00 ||
|
||||||
(smem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
|
(shmem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
|
||||||
x_row[smem_i32[col]] < x_row[smem_i32[ixj]] :
|
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
|
||||||
x_row[smem_i32[col]] > x_row[smem_i32[ixj]]))
|
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
|
||||||
) {
|
) {
|
||||||
SWAP(smem_i32[col], smem_i32[ixj]);
|
SWAP(shmem_i32[col], shmem_i32[ixj]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -4603,7 +4714,7 @@ kernel void kernel_argsort_f32_i32(
|
||||||
if (i00 + col < args.ne00) {
|
if (i00 + col < args.ne00) {
|
||||||
dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
|
dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
|
||||||
|
|
||||||
dst[col] = smem_i32[col];
|
dst[col] = shmem_i32[col];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7558,7 +7558,20 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
test_cases.emplace_back(new test_arange());
|
test_cases.emplace_back(new test_arange());
|
||||||
test_cases.emplace_back(new test_timestep_embedding());
|
test_cases.emplace_back(new test_timestep_embedding());
|
||||||
test_cases.emplace_back(new test_leaky_relu());
|
test_cases.emplace_back(new test_leaky_relu());
|
||||||
test_cases.emplace_back(new test_cumsum());
|
|
||||||
|
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 10, 5, 4, 3 }));
|
||||||
|
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 127, 5, 4, 3 }));
|
||||||
|
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 5, 4, 3 }));
|
||||||
|
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 255, 5, 4, 3 }));
|
||||||
|
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 256, 5, 4, 3 }));
|
||||||
|
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 511, 5, 4, 3 }));
|
||||||
|
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 512, 5, 4, 3 }));
|
||||||
|
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1023, 5, 4, 3 }));
|
||||||
|
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1024, 5, 4, 3 }));
|
||||||
|
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2047, 5, 4, 3 }));
|
||||||
|
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 5, 4, 3 }));
|
||||||
|
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 201*1204, 1, 1, 1 }));
|
||||||
|
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 312*1205, 1, 1, 1 }));
|
||||||
|
|
||||||
test_cases.emplace_back(new test_xielu());
|
test_cases.emplace_back(new test_xielu());
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue