ggml : add new GGML_OP_WHERE_ID (akin to torch where but using indices)
This commit is contained in:
parent
3eb340ed4b
commit
08dc7fd9d9
|
|
@ -558,6 +558,7 @@ extern "C" {
|
|||
GGML_OP_SOLVE_TRI,
|
||||
GGML_OP_GATED_DELTA_NET,
|
||||
GGML_OP_HADAMARD,
|
||||
GGML_OP_WHERE_ID,
|
||||
|
||||
GGML_OP_UNARY,
|
||||
|
||||
|
|
@ -2479,6 +2480,12 @@ extern "C" {
|
|||
struct ggml_tensor * a,
|
||||
int n);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_where_id(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * ids);
|
||||
|
||||
// custom operators
|
||||
|
||||
typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
|
||||
|
|
|
|||
|
|
@ -2029,6 +2029,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||
{
|
||||
ggml_compute_forward_hadamard(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_WHERE_ID:
|
||||
{
|
||||
ggml_compute_forward_where_id(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_MAP_CUSTOM1:
|
||||
{
|
||||
ggml_compute_forward_map_custom1(params, tensor);
|
||||
|
|
@ -2352,6 +2356,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_HADAMARD:
|
||||
case GGML_OP_WHERE_ID:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
|
|
|
|||
|
|
@ -11256,3 +11256,81 @@ void ggml_compute_forward_hadamard(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_where_id
|
||||
|
||||
static void ggml_compute_forward_where_id_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const ggml_tensor * src2 = dst->src[2];
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src2->type == GGML_TYPE_I32);
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int nr = ggml_nrows(src0);
|
||||
|
||||
GGML_TENSOR_TERNARY_OP_LOCALS
|
||||
|
||||
GGML_ASSERT( nb0 == sizeof(float));
|
||||
GGML_ASSERT(nb00 == sizeof(float));
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
// row range for this thread
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
// src0 indices
|
||||
const int i3 = ir/(ne2*ne1);
|
||||
const int i2 = (ir - i3*ne2*ne1)/ne1;
|
||||
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
|
||||
|
||||
const float * src0_ptr = (float *) ((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 );
|
||||
const float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 );
|
||||
const int32_t * ids_ptr = (int32_t *) ((char *) src2->data + i3*nb23 + i2*nb22 + i1*nb21);
|
||||
float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
|
||||
|
||||
// copy whole row from src1
|
||||
ggml_vec_cpy_f32(ne00, dst_ptr, src1_ptr);
|
||||
|
||||
// copy only values from src0 indicated by indices in src2
|
||||
for (int j = 0; j < ne20; ++j) {
|
||||
int id = ids_ptr[j];
|
||||
GGML_ASSERT(id >= 0 && id < ne00);
|
||||
dst_ptr[id] = src0_ptr[id];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_where_id(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_where_id_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("unsupported type for ggml_compute_forward_where_id: %s", ggml_type_name(src0->type));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -104,6 +104,7 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s
|
|||
void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_gated_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_hadamard(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_where_id(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
|
|
|
|||
|
|
@ -62,6 +62,7 @@
|
|||
#include "ggml-cuda/cumsum.cuh"
|
||||
#include "ggml-cuda/fill.cuh"
|
||||
#include "ggml-cuda/hadamard.cuh"
|
||||
#include "ggml-cuda/where-id.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
|
@ -2775,6 +2776,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
case GGML_OP_HADAMARD:
|
||||
ggml_cuda_op_hadamard(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_WHERE_ID:
|
||||
ggml_cuda_op_where_id(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
|
@ -5016,6 +5020,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_TRI:
|
||||
case GGML_OP_DIAG:
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
case GGML_OP_WHERE_ID:
|
||||
return true;
|
||||
case GGML_OP_HADAMARD:
|
||||
return (op->ne[0] == 64 || op->ne[0] == 128 || op->ne[0] == 256) && op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,77 @@
|
|||
#include "where-id.cuh"
|
||||
|
||||
static __global__ void where_id_kernel(
|
||||
const float * src0, const int32_t * src1, float * dst,
|
||||
int64_t ne10, int64_t ne11, int64_t ne12,
|
||||
size_t nb1, size_t nb2,
|
||||
size_t nb01, size_t nb02,
|
||||
size_t nb11, size_t nb12
|
||||
) {
|
||||
|
||||
const int64_t total_blocks = ne11 * ne12;
|
||||
|
||||
for (int64_t block_idx = blockIdx.x; block_idx < total_blocks; block_idx += gridDim.x) {
|
||||
|
||||
const int64_t i1 = block_idx % ne11;
|
||||
const int64_t i2 = block_idx / ne11;
|
||||
|
||||
float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2);
|
||||
const float * src0_row = (const float *)((const char *)src0 + i1*nb01 + i2*nb02);
|
||||
const int * src1_row = (const int *)((const char *)src1 + i1*nb11 + i2*nb12);
|
||||
|
||||
for (int64_t i0 = threadIdx.x; i0 < ne10; i0 += blockDim.x) {
|
||||
const int32_t id = src1_row[i0];
|
||||
dst_row[id] = src0_row[id];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_where_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const ggml_tensor * src2 = dst->src[2];
|
||||
|
||||
GGML_TENSOR_TERNARY_OP_LOCALS
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(ggml_is_contiguous(src2));
|
||||
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src2->type == GGML_TYPE_I32);
|
||||
|
||||
GGML_ASSERT(nb00 == sizeof(float));
|
||||
GGML_ASSERT(nb10 == sizeof(float));
|
||||
GGML_ASSERT(nb20 == sizeof(int32_t));
|
||||
|
||||
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
|
||||
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(dst));
|
||||
|
||||
// step 1 - copy whole src1 to dst
|
||||
cudaStream_t main_stream = ctx.stream();
|
||||
char * dst_ddc = (char *) dst->data;
|
||||
char * src1_ddc = (char *) src1->data;
|
||||
|
||||
CUDA_CHECK(cudaMemcpyAsync(dst_ddc, src1_ddc, ggml_nbytes(src1), cudaMemcpyDeviceToDevice, main_stream));
|
||||
|
||||
// step 2 - copy elements from src0 indicated by ids to dst
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
const int32_t * src2_d = (const int32_t *) src2->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
int threads = std::min((int) ne20, 768); // ids
|
||||
|
||||
int64_t total_blocks = ne21 * ne22;
|
||||
int blocks = (int) std::min((int64_t) 65535, total_blocks);
|
||||
|
||||
where_id_kernel<<<blocks, threads, 0, ctx.stream()>>>(
|
||||
src0_d, src2_d, dst_d,
|
||||
ne20, ne21, ne22,
|
||||
nb1, nb2,
|
||||
nb01, nb02,
|
||||
nb21, nb22
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_op_where_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
|
@ -1033,6 +1033,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||
"SOLVE_TRI",
|
||||
"GATED_DELTA_NET",
|
||||
"HADAMARD",
|
||||
"WHERE_ID",
|
||||
|
||||
"UNARY",
|
||||
|
||||
|
|
@ -1050,7 +1051,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||
"GLU",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97");
|
||||
static_assert(GGML_OP_COUNT == 98, "GGML_OP_COUNT != 98");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
|
|
@ -1144,6 +1145,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||
"A X = B, A triangular, solve X",
|
||||
"gated_delta_net(q, k, v, g, beta, s)",
|
||||
"hadamard(x)",
|
||||
"where_id(x,y,ids)",
|
||||
|
||||
"unary(x)",
|
||||
|
||||
|
|
@ -1161,7 +1163,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||
"glu(x)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97");
|
||||
static_assert(GGML_OP_COUNT == 98, "GGML_OP_COUNT != 98");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
|
|
@ -6201,6 +6203,29 @@ struct ggml_tensor * ggml_hadamard(
|
|||
return result;
|
||||
}
|
||||
|
||||
// ggml_where_id
|
||||
|
||||
struct ggml_tensor * ggml_where_id(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * ids) {
|
||||
|
||||
GGML_ASSERT(a->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(b->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(a->ne[1] == ids->ne[1]);
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne);
|
||||
|
||||
result->op = GGML_OP_WHERE_ID;
|
||||
result->src[0] = a;
|
||||
result->src[1] = b;
|
||||
result->src[2] = ids;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ggml_hash_set ggml_hash_set_new(size_t size) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue