ggml : add ggml_top_k (#17365)
* ggml : add ggml_top_k
* cont : add ggml_argsort_top_k
* metal : add top_k support
* ggml : cleanup
* tests : add virtual err() function for test_case
* ggml : add comments
This commit is contained in:
parent 05872ac885
commit 583cb83416
@@ -530,6 +530,7 @@ extern "C" {
         GGML_OP_ARANGE,
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
+        GGML_OP_TOP_K,
         GGML_OP_LEAKY_RELU,
         GGML_OP_TRI,
         GGML_OP_FILL,

@@ -2258,18 +2259,25 @@ extern "C" {
             struct ggml_tensor  * a,
             enum ggml_sort_order  order);

+    // similar to ggml_top_k but implemented as `argsort` + `view`
+    GGML_API struct ggml_tensor * ggml_argsort_top_k(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   k);
+
+    // top k elements per row
+    // note: the resulting top k indices are in no particular order
+    GGML_API struct ggml_tensor * ggml_top_k(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   k);
+
     GGML_API struct ggml_tensor * ggml_arange(
             struct ggml_context * ctx,
             float                 start,
             float                 stop,
             float                 step);

-    // top k elements per row
-    GGML_API struct ggml_tensor * ggml_top_k(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int                   k);
-
     #define GGML_KQ_MASK_PAD 64

     // q: [n_embd_k, n_batch, n_head, ne3 ]
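Before the per-backend changes, a minimal sketch of how the two new entry points might be called. The shapes and the helper name are illustrative assumptions, not part of the commit:

```cpp
#include "ggml.h"

// hypothetical helper: pick the 8 largest logits per row
static struct ggml_tensor * build_top_k_example(struct ggml_context * ctx) {
    struct ggml_tensor * logits = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 32);

    // I32 indices of the 8 largest values per row, in no particular order
    struct ggml_tensor * idx = ggml_top_k(ctx, logits, 8);                // [8, 32]

    // the same selection, sorted descending (built as argsort + view)
    struct ggml_tensor * idx_sorted = ggml_argsort_top_k(ctx, logits, 8); // [8, 32]

    (void) idx_sorted;
    return idx;
}
```

Since the new GGML_OP_TOP_K leaves its indices unordered, call sites that depend on ranked results (such as the MoE selection further down) switch to ggml_argsort_top_k, which reproduces the previous sorted behavior.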
@@ -1927,6 +1927,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_argsort(params, tensor);
             } break;
+        case GGML_OP_TOP_K:
+            {
+                ggml_compute_forward_top_k(params, tensor);
+            } break;
         case GGML_OP_LEAKY_RELU:
             {
                 ggml_compute_forward_leaky_relu(params, tensor);

@@ -2311,6 +2315,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
+        case GGML_OP_TOP_K:
         case GGML_OP_FLASH_ATTN_EXT:
         case GGML_OP_FLASH_ATTN_BACK:
         case GGML_OP_SSM_CONV:

@@ -2834,6 +2839,10 @@ struct ggml_cplan ggml_graph_plan(
                     cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
                     cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
                 } break;
+            case GGML_OP_TOP_K:
+                {
+                    cur += sizeof(int32_t)*node->src[0]->ne[0]*n_tasks;
+                } break;
            case GGML_OP_FLASH_ATTN_EXT:
                 {
                     const int64_t ne10 = node->src[1]->ne[0]; // DK
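The planner change above reserves one row of int32 index scratch per worker thread for GGML_OP_TOP_K. The kernel (next file) then slices params->wdata by thread index, leaving a cache line between slices. A sketch of the implied layout; reading CACHE_LINE_SIZE_F32 as ggml-cpu's cache-line size counted in 4-byte elements is my inference, not stated in the diff:

```cpp
// sketch, not the literal ggml code: per-thread scratch slices inside wdata;
// the cache-line gap between slices avoids false sharing between threads
static int32_t * top_k_scratch(void * wdata, int64_t ne00, int ith) {
    return (int32_t *) wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
}
```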
@@ -7794,7 +7794,7 @@ void ggml_compute_forward_timestep_embedding(
 // ggml_compute_forward_argsort

 template<enum ggml_sort_order order>
-struct argsort_cmp {
+struct cmp_argsort {
     const float * data;
     bool operator()(int32_t a, int32_t b) const {
         if constexpr (order == GGML_SORT_ORDER_ASC) {

@@ -7833,11 +7833,11 @@ static void ggml_compute_forward_argsort_f32(

         switch (order) {
             case GGML_SORT_ORDER_ASC:
-                std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
+                std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
                 break;
             case GGML_SORT_ORDER_DESC:
-                std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
+                std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
                 break;
             default:

@@ -7864,6 +7864,72 @@ void ggml_compute_forward_argsort(
     }
 }

+// ggml_compute_forward_top_k
+
+struct cmp_top_k {
+    const float * data;
+    bool operator()(int32_t a, int32_t b) const {
+        return data[a] > data[b];
+    }
+};
+
+static void ggml_compute_forward_top_k_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(nb0 == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    const int top_k = ne0;
+
+    int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+    for (int64_t i = ith; i < nr; i += nth) {
+        const float * src_data = (float *)((char *) src0->data + i*nb01);
+
+        for (int64_t j = 0; j < ne00; j++) {
+            tmp[j] = j;
+        }
+
+        std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
+
+        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+
+        std::copy(tmp, tmp + top_k, dst_data);
+
+        // emphasize that the order is not important
+        if (top_k > 1) {
+            std::swap(dst_data[0], dst_data[1]);
+        }
+    }
+}
+
+void ggml_compute_forward_top_k(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_top_k_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_flash_attn_ext

 static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
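The CPU path uses the classic index-comparator trick: fill tmp with 0..ne00-1, then let std::partial_sort order only the first top_k positions, which is roughly O(n log k) instead of a full O(n log n) sort. A self-contained sketch of the same idea, assuming nothing from ggml:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

// top-k indices of a row, as in ggml_compute_forward_top_k_f32:
// only the first k positions end up sorted, the rest of the range is unspecified
static std::vector<int32_t> top_k_indices(const std::vector<float> & row, int k) {
    std::vector<int32_t> idx(row.size());
    std::iota(idx.begin(), idx.end(), 0);

    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
        [&row](int32_t a, int32_t b) { return row[a] > row[b]; });

    idx.resize(k);
    return idx;
}

int main() {
    const std::vector<float> row = { 0.1f, 0.9f, 0.4f, 0.7f, 0.2f };
    for (int32_t i : top_k_indices(row, 3)) {
        printf("%d ", i); // prints: 1 3 2
    }
    printf("\n");
}
```

The deliberate std::swap of the first two outputs in the ggml kernel makes the documented "no particular order" contract observable, so neither callers nor backend tests can silently start relying on sorted output.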
@@ -81,6 +81,7 @@ void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct
 void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_top_k(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -1009,6 +1009,64 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_l
     return res;
 }

+// note: reuse the argsort kernel for top_k
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) {
+    assert(op->op == GGML_OP_TOP_K);
+
+    char base[256];
+    char name[256];
+
+    // note: the top_k kernel is always descending order
+    ggml_sort_order order = GGML_SORT_ORDER_DESC;
+
+    const char * order_str = "undefined";
+    switch (order) {
+        case GGML_SORT_ORDER_ASC:  order_str = "asc";  break;
+        case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
+        default: GGML_ABORT("fatal error");
+    };
+
+    snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
+    assert(op->op == GGML_OP_TOP_K);
+
+    char base[256];
+    char name[256];
+
+    ggml_sort_order order = GGML_SORT_ORDER_DESC;
+
+    const char * order_str = "undefined";
+    switch (order) {
+        case GGML_SORT_ORDER_ASC:  order_str = "asc";  break;
+        case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
+        default: GGML_ABORT("fatal error");
+    };
+
+    snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
         ggml_metal_library_t lib,
         const struct ggml_tensor * op,
@@ -128,6 +128,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id      (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax         (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort        (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge  (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k          (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge    (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin            (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm        (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm     (ggml_metal_library_t lib, const struct ggml_tensor * op);
@@ -905,6 +905,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_LEAKY_RELU:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ARGSORT:
+        case GGML_OP_TOP_K:
         case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
@@ -832,14 +832,19 @@ typedef struct {
 } ggml_metal_kargs_leaky_relu;

 typedef struct {
-    int64_t  ne00;
-    int64_t  ne01;
-    int64_t  ne02;
-    int64_t  ne03;
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
     uint64_t nb00;
     uint64_t nb01;
     uint64_t nb02;
     uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    int32_t  top_k;
 } ggml_metal_kargs_argsort;

 typedef struct {

@@ -851,6 +856,11 @@ typedef struct {
     uint64_t nb01;
     uint64_t nb02;
     uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    int32_t  top_k;
     int32_t  len;
 } ggml_metal_kargs_argsort_merge;
@@ -406,6 +406,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_argsort(ctx, idx);
             } break;
+        case GGML_OP_TOP_K:
+            {
+                n_fuse = ggml_metal_op_top_k(ctx, idx);
+            } break;
         case GGML_OP_LEAKY_RELU:
             {
                 n_fuse = ggml_metal_op_leaky_relu(ctx, idx);

@@ -3686,6 +3690,11 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
         /*.nb01 =*/ nb01,
         /*.nb02 =*/ nb02,
         /*.nb03 =*/ nb03,
+        /*.ne0  =*/ ne0,
+        /*.ne1  =*/ ne1,
+        /*.ne2  =*/ ne2,
+        /*.ne3  =*/ ne3,
+        /*.top_k =*/ nth,
     };

     ggml_metal_encoder_set_pipeline(enc, pipeline);

@@ -3705,15 +3714,20 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
         ggml_metal_op_concurrency_reset(ctx);

         ggml_metal_kargs_argsort_merge args_merge = {
-            .ne00 = ne00,
-            .ne01 = ne01,
-            .ne02 = ne02,
-            .ne03 = ne03,
-            .nb00 = nb00,
-            .nb01 = nb01,
-            .nb02 = nb02,
-            .nb03 = nb03,
-            .len  = len,
+            /*.ne00 =*/ ne00,
+            /*.ne01 =*/ ne01,
+            /*.ne02 =*/ ne02,
+            /*.ne03 =*/ ne03,
+            /*.nb00 =*/ nb00,
+            /*.nb01 =*/ nb01,
+            /*.nb02 =*/ nb02,
+            /*.nb03 =*/ nb03,
+            /*.ne0  =*/ ne0,
+            /*.ne1  =*/ ne1,
+            /*.ne2  =*/ ne2,
+            /*.ne3  =*/ ne3,
+            /*.top_k =*/ ne00,
+            /*.len  =*/ len,
         };

         // merges per row

@@ -3737,6 +3751,118 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
     return 1;
 }

+int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
+
+    // bitonic sort requires the number of elements to be power of 2
+    int nth = 1;
+    while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+        nth *= 2;
+    }
+
+    // blocks per row
+    const int npr = (ne00 + nth - 1)/nth;
+
+    const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
+
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
+    ggml_metal_buffer_id bid_tmp = bid_dst;
+    bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);
+
+    if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
+        std::swap(bid_dst, bid_tmp);
+    }
+
+    const int top_k = ne0;
+
+    ggml_metal_kargs_argsort args = {
+        /*.ne00 =*/ ne00,
+        /*.ne01 =*/ ne01,
+        /*.ne02 =*/ ne02,
+        /*.ne03 =*/ ne03,
+        /*.nb00 =*/ nb00,
+        /*.nb01 =*/ nb01,
+        /*.nb02 =*/ nb02,
+        /*.nb03 =*/ nb03,
+        /*.ne0  =*/ ne0,
+        /*.ne1  =*/ ne1,
+        /*.ne2  =*/ ne2,
+        /*.ne3  =*/ ne3,
+        /*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices
+    };
+
+    if (npr > 1) {
+        args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
+    }
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
+
+    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
+
+    ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
+
+    int len = args.top_k;
+
+    while (len < args.ne0) {
+        ggml_metal_op_concurrency_reset(ctx);
+
+        // merges per row
+        const int nm = (args.ne0 + 2*len - 1) / (2*len);
+
+        const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
+
+        ggml_metal_kargs_argsort_merge args_merge = {
+            /*.ne00 =*/ ne00,
+            /*.ne01 =*/ ne01,
+            /*.ne02 =*/ ne02,
+            /*.ne03 =*/ ne03,
+            /*.nb00 =*/ nb00,
+            /*.nb01 =*/ nb01,
+            /*.nb02 =*/ nb02,
+            /*.nb03 =*/ nb03,
+            /*.ne0  =*/ args.ne0,
+            /*.ne1  =*/ ne1,
+            /*.ne2  =*/ ne2,
+            /*.ne3  =*/ ne3,
+            /*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements
+            /*.len  =*/ len,
+        };
+
+        ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
+        ggml_metal_encoder_set_bytes   (enc, &args_merge, sizeof(args_merge), 0);
+        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+        ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
+        ggml_metal_encoder_set_buffer  (enc, bid_tmp,  3);
+
+        ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
+
+        std::swap(bid_dst, bid_tmp);
+
+        len <<= 1;
+    }
+
+    return 1;
+}
+
 int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
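The Metal path reuses the argsort machinery: each row is split into power-of-two blocks of nth elements, a bitonic sort keeps min(nth, top_k) candidates per block, and a merge kernel then combines pairs of sorted runs, doubling len each pass until one run remains; only the final merge truncates to top_k. Passes ping-pong between the two halves of the double-sized dst buffer, and the ceil(log2(npr)) parity check pre-swaps bid_dst and bid_tmp so the last pass lands in the visible half. A host-side sketch of this reduction schedule, under the assumption of descending-sorted runs (unlike the kernel, the sketch also truncates intermediate merges, which is equally valid since the global top-k is contained in the union of per-run top-k's):

```cpp
#include <algorithm>
#include <functional>
#include <iterator>
#include <vector>

// sketch: reduce per-block candidate runs to a single top-k run, mirroring
// the len-doubling merge loop in ggml_metal_op_top_k
static std::vector<float> reduce_top_k(std::vector<std::vector<float>> runs, size_t top_k) {
    while (runs.size() > 1) {
        std::vector<std::vector<float>> next;
        for (size_t i = 0; i < runs.size(); i += 2) {
            if (i + 1 == runs.size()) {
                next.push_back(std::move(runs[i])); // odd run out, carried to the next pass
                continue;
            }
            std::vector<float> merged;
            std::merge(runs[i].begin(), runs[i].end(),
                       runs[i + 1].begin(), runs[i + 1].end(),
                       std::back_inserter(merged), std::greater<float>());
            if (merged.size() > top_k) {
                merged.resize(top_k); // only the head can ever reach the final result
            }
            next.push_back(std::move(merged));
        }
        runs.swap(next);
    }
    return runs.empty() ? std::vector<float>{} : runs[0];
}
```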
@@ -81,6 +81,7 @@ int ggml_metal_op_arange            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_argmax            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_argsort           (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_top_k             (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_leaky_relu        (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_opt_step_adamw    (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_opt_step_sgd      (ggml_metal_op_t ctx, int idx);
@@ -202,6 +202,10 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
             {
                 res *= 2;
             } break;
+        case GGML_OP_TOP_K:
+            {
+                res = 2*sizeof(int32_t)*ggml_nelements(tensor->src[0]);
+            } break;
         default:
             break;
     }
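The 2x factor pairs with the ping-pong scheme above: the first sizeof(int32_t)*ggml_nelements(src0) bytes hold the candidate indices, and the equally sized region behind them (addressed via bid_tmp.offs in the encoder code) is the scratch half for the merge passes. Spelled out with hypothetical sizes:

```cpp
#include <cstddef>
#include <cstdint>

// illustrative arithmetic only: a TOP_K dst whose src0 has 65000 x 16 elements
const size_t nel = 65000ull * 16;             // ggml_nelements(tensor->src[0])
const size_t res = 2 * sizeof(int32_t) * nel; // candidate indices + ping-pong scratch
```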
@@ -4670,9 +4670,10 @@ kernel void kernel_argsort_f32_i32(
         ushort3 ntg[[threads_per_threadgroup]]) {
     // bitonic sort
     const int col = tpitg[0];
+    const int ib  = tgpig[0] / args.ne01;

-    const int i00 = (tgpig[0]/args.ne01)*ntg.x;
-    const int i01 = tgpig[0]%args.ne01;
+    const int i00 = ib*ntg.x;
+    const int i01 = tgpig[0] % args.ne01;
     const int i02 = tgpig[1];
     const int i03 = tgpig[2];

@@ -4710,9 +4711,11 @@ kernel void kernel_argsort_f32_i32(
         }
     }

+    const int64_t i0 = ib*args.top_k;
+
     // copy the result to dst without the padding
-    if (i00 + col < args.ne00) {
-        dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
+    if (i0 + col < args.ne0 && col < args.top_k) {
+        dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;

         dst[col] = shmem_i32[col];
     }

@@ -4747,22 +4750,22 @@ kernel void kernel_argsort_merge_f32_i32(

     const int start = im * (2 * args.len);

-    const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
-    const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
+    const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
+    const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));

     const int total = len0 + len1;

     device const int32_t * tmp0 = tmp + start
-        + i01*args.ne00
-        + i02*args.ne00*args.ne01
-        + i03*args.ne00*args.ne01*args.ne02;
+        + i01*args.ne0
+        + i02*args.ne0*args.ne01
+        + i03*args.ne0*args.ne01*args.ne02;

     device const int32_t * tmp1 = tmp0 + args.len;

     dst += start
-        + i01*args.ne00
-        + i02*args.ne00*args.ne01
-        + i03*args.ne00*args.ne01*args.ne02;
+        + i01*args.top_k
+        + i02*args.top_k*args.ne01
+        + i03*args.top_k*args.ne01*args.ne02;

     device const float * src0_row = (device const float *)(src0
         + args.nb01*i01

@@ -4776,7 +4779,11 @@ kernel void kernel_argsort_merge_f32_i32(
     const int chunk = (total + ntg.x - 1) / ntg.x;

     const int k0 = tpitg.x * chunk;
-    const int k1 = min(k0 + chunk, total);
+    const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
+
+    if (k0 >= args.top_k) {
+        return;
+    }

     if (k0 >= total) {
         return;
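With the top_k parameters threaded through, the merge kernel clamps each thread's output chunk at args.top_k and lets threads whose chunk starts at or beyond that bound exit early, so the final merge never materializes more than top_k entries per row. The chunking arithmetic, restated host-side as a sketch:

```cpp
#include <algorithm>

// sketch of the per-thread bounds computed in kernel_argsort_merge_f32_i32;
// total: combined length of the two runs, ntg: threads per threadgroup
struct chunk_bounds { int k0, k1; bool active; };

static chunk_bounds merge_chunk(int total, int ntg, int tid, int top_k) {
    const int chunk = (total + ntg - 1) / ntg;   // ceil-divide the output range
    const int k0    = tid * chunk;
    const int k1    = std::min(std::min(k0 + chunk, total), top_k);
    return { k0, k1, k0 < total && k0 < top_k }; // inactive threads return early
}
```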
@@ -990,6 +990,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ARANGE",
     "TIMESTEP_EMBEDDING",
     "ARGSORT",
+    "TOP_K",
     "LEAKY_RELU",
     "TRI",
     "FILL",

@@ -1023,7 +1024,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };

-static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
+static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",

@@ -1098,6 +1099,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "arange(start, stop, step)",
     "timestep_embedding(timesteps, dim, max_period)",
     "argsort(x)",
+    "top_k(x)",
     "leaky_relu(x)",
     "tri(x)",
     "fill(x, c)",

@@ -1131,7 +1133,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };

-static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
+static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -5036,28 +5038,6 @@ struct ggml_tensor * ggml_roll(
     return result;
 }

-// ggml_arange
-
-struct ggml_tensor * ggml_arange(
-        struct ggml_context * ctx,
-        float start,
-        float stop,
-        float step) {
-    GGML_ASSERT(stop > start);
-
-    const int64_t steps = (int64_t) ceilf((stop - start) / step);
-
-    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
-
-    ggml_set_op_params_f32(result, 0, start);
-    ggml_set_op_params_f32(result, 1, stop);
-    ggml_set_op_params_f32(result, 2, step);
-
-    result->op = GGML_OP_ARANGE;
-
-    return result;
-}
-
 // ggml_timestep_embedding

 struct ggml_tensor * ggml_timestep_embedding(

@@ -5139,6 +5119,7 @@ struct ggml_tensor * ggml_argsort(
         struct ggml_tensor * a,
         enum ggml_sort_order order) {
+    GGML_ASSERT(a->ne[0] <= INT32_MAX);

     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);

     ggml_set_op_params_i32(result, 0, (int32_t) order);

@@ -5149,6 +5130,24 @@ struct ggml_tensor * ggml_argsort(
     return result;
 }

+// ggml_argsort_top_k
+
+struct ggml_tensor * ggml_argsort_top_k(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   k) {
+    GGML_ASSERT(a->ne[0] >= k);
+
+    struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC);
+
+    result = ggml_view_4d(ctx, result,
+            k, result->ne[1], result->ne[2], result->ne[3],
+            result->nb[1], result->nb[2], result->nb[3],
+            0);
+
+    return result;
+}
+
 // ggml_top_k

 struct ggml_tensor * ggml_top_k(

@@ -5157,12 +5156,32 @@ struct ggml_tensor * ggml_top_k(
         int k) {
     GGML_ASSERT(a->ne[0] >= k);

-    struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC);
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, k, a->ne[1], a->ne[2], a->ne[3]);

-    result = ggml_view_4d(ctx, result,
-            k, result->ne[1], result->ne[2], result->ne[3],
-            result->nb[1], result->nb[2], result->nb[3],
-            0);
+    result->op     = GGML_OP_TOP_K;
+    result->src[0] = a;

     return result;
 }
+
+// ggml_arange
+
+struct ggml_tensor * ggml_arange(
+        struct ggml_context * ctx,
+        float start,
+        float stop,
+        float step) {
+    GGML_ASSERT(stop > start);
+
+    const int64_t steps = (int64_t) ceilf((stop - start) / step);
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
+
+    ggml_set_op_params_f32(result, 0, start);
+    ggml_set_op_params_f32(result, 1, stop);
+    ggml_set_op_params_f32(result, 2, step);
+
+    result->op = GGML_OP_ARANGE;
+
+    return result;
+}
@@ -961,14 +961,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             // organize experts into n_expert_groups
             ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]

-            ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
+            ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
             group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]

             // get top n_group_used expert groups
             group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
             group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]

-            ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
+            ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
             cb(expert_groups, "ffn_moe_group_topk", il);

             // mask out the other groups

@@ -979,7 +979,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         }

         // select experts
-        ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
+        ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
         cb(selected_experts->src[0], "ffn_moe_argsort", il);
         cb(selected_experts, "ffn_moe_topk", il);
@@ -39,6 +39,7 @@
 #include <string_view>
 #include <thread>
 #include <vector>
+#include <unordered_map>

 static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
     size_t nels = ggml_nelements(tensor);
@@ -269,6 +270,34 @@ static double nmse(const float * a, const float * b, size_t n) {
     return mse_a_b / mse_a_0;
 }

+// difference between 2 integer sets (Jaccard distance, 0 - no difference, 1 - no overlap)
+static double jdst(const int32_t * a, const int32_t * b, size_t n) {
+    std::unordered_map<int32_t, size_t> set_a;
+    std::unordered_map<int32_t, size_t> set_b;
+
+    for (size_t i = 0; i < n; ++i) {
+        set_a[a[i]]++;
+        set_b[b[i]]++;
+    }
+
+    size_t diff = 0;
+
+    for (const auto & p : set_a) {
+        const int64_t na = p.second;
+        const int64_t nb = set_b.find(p.first) != set_b.end() ? set_b.at(p.first) : 0;
+
+        diff += std::abs(na - nb);
+    }
+
+    for (const auto & p : set_b) {
+        if (set_a.find(p.first) == set_a.end()) {
+            diff += p.second;
+        }
+    }
+
+    return (double) diff / (2*n);
+}
+
 // maximum absolute asymmetry between a and b
 // asymmetry: (a - b) / (a + b)
 // This is more stable than relative error if one of the values fluctuates towards zero.
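jdst is a multiset Jaccard distance over the two index lists: it sums the per-value multiplicity mismatches and normalizes by the combined size 2n, so identical multisets score 0 and disjoint ones score 1. A worked example, assuming the jdst defined above:

```cpp
// a = {1, 2, 3}, b = {2, 3, 4}: the symmetric difference is {1, 4} -> diff = 2
const int32_t a[3] = {1, 2, 3};
const int32_t b[3] = {2, 3, 4};
const double  d    = jdst(a, b, 3); // 2 / (2*3) = 0.333...
```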
@@ -1051,6 +1080,14 @@ struct test_case {
         return 1e-4;
     }

+    virtual double max_err() {
+        return max_nmse_err();
+    }
+
+    virtual double err(const float * a, const float * b, size_t n) {
+        return nmse(a, b, n);
+    }
+
     virtual float grad_eps() {
         return 1e-1f;
     }
@@ -1257,16 +1294,16 @@ struct test_case {
         // compare
         struct callback_userdata {
             bool   ok;
-            double max_err;
+            test_case * tc;
             ggml_backend_t backend1;
             ggml_backend_t backend2;
         };

         callback_userdata ud {
             true,
-            max_nmse_err(),
+            this,
             backend1,
-            backend2
+            backend2,
         };

         auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool {

@@ -1314,9 +1351,9 @@ struct test_case {
             }
         }

-        double err = nmse(f1.data(), f2.data(), f1.size());
-        if (err > ud->max_err) {
-            printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
+        double err = ud->tc->err(f1.data(), f2.data(), f1.size());
+        if (err > ud->tc->max_err()) {
+            printf("[%s] ERR = %.9f > %.9f ", ggml_op_desc(t1), err, ud->tc->max_err());
             //for (int i = 0; i < (int) f1.size(); i++) {
             //    printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
             //}
@@ -4943,7 +4980,71 @@ struct test_argsort : public test_case {
     }
 };

-struct test_topk_moe: public test_case {
+// GGML_OP_TOP_K
+struct test_top_k : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const int k;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, k);
+    }
+
+    test_top_k(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {16, 10, 10, 10},
+            int k = 4)
+        : type(type), ne(ne), k(k) {}
+
+    double max_err() override {
+        return 0.0;
+    }
+
+    double err(const float * a, const float * b, size_t n) override {
+        std::vector<int32_t> ia(n);
+        std::vector<int32_t> ib(n);
+
+        double diff = 0.0f;
+
+        for (size_t i = 0; i < n; i++) {
+            ia[i] = (int32_t) a[i];
+            ib[i] = (int32_t) b[i];
+
+            // penalize the result if the data is not integer valued
+            diff += std::fabs(a[i] - ia[i]);
+            diff += std::fabs(b[i] - ib[i]);
+        }
+
+        return diff + jdst(ia.data(), ib.data(), n);
+    }
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_top_k(ctx, a, k);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        std::random_device rd;
+        std::default_random_engine rng(rd());
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            // initialize with unique values to avoid ties
+            for (int64_t r = 0; r < ggml_nrows(t); r++) {
+                std::vector<float> data(t->ne[0]);
+                for (int i = 0; i < t->ne[0]; i++) {
+                    data[i] = i;
+                }
+                std::shuffle(data.begin(), data.end(), rng);
+                ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
+            }
+        }
+    }
+};
+
+struct test_topk_moe : public test_case {
     const std::array<int64_t, 4> ne;
     const int n_expert_used;
     const bool with_norm;
@@ -4976,7 +5077,7 @@ struct test_topk_moe : public test_case {

         ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
         ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
-        ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
+        ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]

         ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
@@ -7534,6 +7635,23 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
     }

+    for (int k : {1, 2, 3, 7, 15}) {
+        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16,    10, 10, 10}, k));
+        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {60,    10, 10, 10}, k));
+        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1023,  2,  1,  3},  k));
+        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1024,  2,  1,  3},  k));
+        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1025,  2,  1,  3},  k));
+        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16384, 1,  1,  1},  k));
+        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2047,  2,  1,  3},  k));
+        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2048,  2,  1,  3},  k));
+        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2049,  2,  1,  3},  k));
+    }
+
+    // exhaustive top_k tests
+    //for (int i = 1; i < 9999; ++i) {
+    //    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {i, 2, 1, 3}, rand() % i + 1));
+    //}
+
     for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) {
         test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
         test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
@@ -7914,6 +8032,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     }

     test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
+    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 16, 1, 1}, 40));

     return test_cases;
 }