diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 547ccc42aa..82186fe8f6 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -558,6 +558,7 @@ extern "C" {
         GGML_OP_SOLVE_TRI,
         GGML_OP_GATED_DELTA_NET,
         GGML_OP_HADAMARD,
+        GGML_OP_WHERE_ID,
 
         GGML_OP_UNARY,
 
@@ -2479,6 +2480,12 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   n);
 
+    GGML_API struct ggml_tensor * ggml_where_id(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * ids);
+
     // custom operators
 
     typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index bed01ae65c..e5e5f0507e 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -2029,6 +2029,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_hadamard(params, tensor);
             } break;
+        case GGML_OP_WHERE_ID:
+            {
+                ggml_compute_forward_where_id(params, tensor);
+            } break;
         case GGML_OP_MAP_CUSTOM1:
             {
                 ggml_compute_forward_map_custom1(params, tensor);
@@ -2352,6 +2356,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_RWKV_WKV7:
         case GGML_OP_HADAMARD:
+        case GGML_OP_WHERE_ID:
             {
                 n_tasks = n_threads;
             } break;
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 111a474a6f..c4a77b29e9 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -11256,3 +11256,81 @@ void ggml_compute_forward_hadamard(
         }
     }
 }
+
+// ggml_compute_forward_where_id
+
+static void ggml_compute_forward_where_id_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        const float   * src0_ptr = (float   *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+        const float   * src1_ptr = (float   *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+        const int32_t * ids_ptr  = (int32_t *) ((char *) src2->data + i3*nb23 + i2*nb22 + i1*nb21);
+        float         * dst_ptr  = (float   *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
+
+        // copy whole row from src1
+        ggml_vec_cpy_f32(ne00, dst_ptr, src1_ptr);
+
+        // copy only values from src0 indicated by indices in src2
+        for (int j = 0; j < ne20; ++j) {
+            int id = ids_ptr[j];
+            GGML_ASSERT(id >= 0 && id < ne00);
+            dst_ptr[id] = src0_ptr[id];
+        }
+    }
+}
+
+void ggml_compute_forward_where_id(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_where_id_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("unsupported type for ggml_compute_forward_where_id: %s", ggml_type_name(src0->type));
+            }
+    }
+}
diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h
index c28d32ea91..30b3e6d311 100644
--- a/ggml/src/ggml-cpu/ops.h
+++ b/ggml/src/ggml-cpu/ops.h
@@ -104,6 +104,7 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s
 void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_gated_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_hadamard(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_where_id(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 6a091a6d8a..da2b54e137 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -62,6 +62,7 @@
 #include "ggml-cuda/cumsum.cuh"
 #include "ggml-cuda/fill.cuh"
 #include "ggml-cuda/hadamard.cuh"
+#include "ggml-cuda/where-id.cuh"
 #include "ggml.h"
 
 #include
@@ -2775,6 +2776,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_HADAMARD:
             ggml_cuda_op_hadamard(ctx, dst);
             break;
+        case GGML_OP_WHERE_ID:
+            ggml_cuda_op_where_id(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -5016,6 +5020,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_TRI:
         case GGML_OP_DIAG:
         case GGML_OP_SOLVE_TRI:
+        case GGML_OP_WHERE_ID:
             return true;
         case GGML_OP_HADAMARD:
             return (op->ne[0] == 64 || op->ne[0] == 128 || op->ne[0] == 256) && op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
diff --git a/ggml/src/ggml-cuda/where-id.cu b/ggml/src/ggml-cuda/where-id.cu
new file mode 100644
index 0000000000..993873462b
--- /dev/null
+++ b/ggml/src/ggml-cuda/where-id.cu
@@ -0,0 +1,77 @@
+#include "where-id.cuh"
+
+static __global__ void where_id_kernel(
+        const float * src0, const int32_t * src1, float * dst,
+        int64_t ne10, int64_t ne11, int64_t ne12,
+        size_t nb1,  size_t nb2,
+        size_t nb01, size_t nb02,
+        size_t nb11, size_t nb12
+    ) {
+
+    const int64_t total_blocks = ne11 * ne12;
+
+    for (int64_t block_idx = blockIdx.x; block_idx < total_blocks; block_idx += gridDim.x) {
+
+        const int64_t i1 = block_idx % ne11;
+        const int64_t i2 = block_idx / ne11;
+
+        float       * dst_row  = (float *)((char *)dst + i1*nb1 + i2*nb2);
+        const float * src0_row = (const float *)((const char *)src0 + i1*nb01 + i2*nb02);
+        const int   * src1_row = (const int *)((const char *)src1 + i1*nb11 + i2*nb12);
+
+        for (int64_t i0 = threadIdx.x; i0 < ne10; i0 += blockDim.x) {
+            const int32_t id = src1_row[i0];
+            dst_row[id] = src0_row[id];
+        }
+    }
+}
+
+void ggml_cuda_op_where_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(src2));
+
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb20 == sizeof(int32_t));
+
+    GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
+    GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(dst));
+
+    // step 1 - copy whole src1 to dst
+    cudaStream_t main_stream = ctx.stream();
+    char * dst_ddc  = (char *) dst->data;
+    char * src1_ddc = (char *) src1->data;
+
+    CUDA_CHECK(cudaMemcpyAsync(dst_ddc, src1_ddc, ggml_nbytes(src1), cudaMemcpyDeviceToDevice, main_stream));
+
+    // step 2 - copy elements from src0 indicated by ids to dst
+    const float   * src0_d = (const float   *) src0->data;
+    const int32_t * src2_d = (const int32_t *) src2->data;
+    float         * dst_d  = (float         *) dst->data;
+
+    int threads = std::min((int) ne20, 768); // ids
+
+    int64_t total_blocks = ne21 * ne22;
+    int blocks = (int) std::min((int64_t) 65535, total_blocks);
+
+    where_id_kernel<<<blocks, threads, 0, main_stream>>>(
+        src0_d, src2_d, dst_d,
+        ne20, ne21, ne22,
+        nb1, nb2,
+        nb01, nb02,
+        nb21, nb22
+    );
+}
diff --git a/ggml/src/ggml-cuda/where-id.cuh b/ggml/src/ggml-cuda/where-id.cuh
new file mode 100644
index 0000000000..bf3ea095a8
--- /dev/null
+++ b/ggml/src/ggml-cuda/where-id.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_where_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index a01ee49ee3..7132c1f215 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1033,6 +1033,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "SOLVE_TRI",
     "GATED_DELTA_NET",
     "HADAMARD",
+    "WHERE_ID",
 
     "UNARY",
 
@@ -1050,7 +1051,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97");
+static_assert(GGML_OP_COUNT == 98, "GGML_OP_COUNT != 98");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1144,6 +1145,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "A X = B, A triangular, solve X",
     "gated_delta_net(q, k, v, g, beta, s)",
     "hadamard(x)",
+    "where_id(x,y,ids)",
 
     "unary(x)",
 
@@ -1161,7 +1163,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97");
+static_assert(GGML_OP_COUNT == 98, "GGML_OP_COUNT != 98");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -6201,6 +6203,29 @@ struct ggml_tensor * ggml_hadamard(
     return result;
 }
 
+// ggml_where_id
+
+struct ggml_tensor * ggml_where_id(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * ids) {
+
+    GGML_ASSERT(a->type   == GGML_TYPE_F32);
+    GGML_ASSERT(b->type   == GGML_TYPE_F32);
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
+    GGML_ASSERT(a->ne[1]  == ids->ne[1]);
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne);
+
+    result->op     = GGML_OP_WHERE_ID;
+    result->src[0] = a;
+    result->src[1] = b;
+    result->src[2] = ids;
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 struct ggml_hash_set ggml_hash_set_new(size_t size) {
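
Usage sketch (not part of the patch): the snippet below builds a small graph with the new ggml_where_id op and runs it on the CPU backend. It assumes the usual ggml headers (ggml.h, ggml-cpu.h) and ggml_graph_compute_with_ctx(); the tensor sizes, fill values and thread count are illustrative only. The result starts as a copy of b, and only the columns listed per row in ids are taken from a.

#include "ggml.h"
#include "ggml-cpu.h"

#include <stdio.h>

int main(void) {
    // small context with its own allocation buffer (no_alloc = false)
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    const int64_t n_cols = 8;  // row length of a and b
    const int64_t n_rows = 2;  // number of rows
    const int64_t n_ids  = 3;  // selected columns per row

    struct ggml_tensor * a   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_cols, n_rows);
    struct ggml_tensor * b   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_cols, n_rows);
    struct ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_ids,  n_rows);

    // a = 1, b = 0; every row selects columns 0, 3 and 5 from a
    for (int64_t i = 0; i < n_cols*n_rows; ++i) {
        ((float *) a->data)[i] = 1.0f;
        ((float *) b->data)[i] = 0.0f;
    }
    for (int64_t r = 0; r < n_rows; ++r) {
        int32_t * row = (int32_t *) ((char *) ids->data + r*ids->nb[1]);
        row[0] = 0; row[1] = 3; row[2] = 5;
    }

    // dst starts as a copy of b; the positions listed in ids are overwritten with a
    struct ggml_tensor * dst = ggml_where_id(ctx, a, b, ids);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, dst);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    // expected first row: 1 0 0 1 0 1 0 0
    for (int64_t i = 0; i < n_cols; ++i) {
        printf("%.0f ", ((float *) dst->data)[i]);
    }
    printf("\n");

    ggml_free(ctx);
    return 0;
}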