ggml : replace GGML_OP_WHERE_ID with GGML_OP_SCATTER, which works similarly to torch's scatter_ operation
This commit is contained in:
parent
4309c8486a
commit
9b0a4eea57
@@ -558,7 +558,7 @@ extern "C" {
         GGML_OP_SOLVE_TRI,
         GGML_OP_GATED_DELTA_NET,
         GGML_OP_HADAMARD,
-        GGML_OP_WHERE_ID,
+        GGML_OP_SCATTER,

         GGML_OP_UNARY,
@@ -2480,11 +2480,11 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   n);

-    GGML_API struct ggml_tensor * ggml_where_id(
+    GGML_API struct ggml_tensor * ggml_scatter(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            struct ggml_tensor  * ids);
+            struct ggml_tensor  * ids,
+            float                 c);

     // custom operators
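The API change is visible here: `ggml_scatter(ctx, a, ids, c)` takes one value tensor and a scalar instead of two value tensors, copying `a` into the result and overwriting the indexed elements of each row with `c`, loosely mirroring torch's `scatter_` with a scalar source. A minimal usage sketch (hypothetical shapes, assuming an initialized `ggml_context * ctx`; `ids` holds per-row column indices into `a`, and the row counts of `a` and `ids` must match):

    #include "ggml.h"

    // a:   4 rows of 128 F32 values, copied verbatim into the result
    // ids: 4 rows of 16 I32 column indices, one index row per data row
    struct ggml_tensor * a   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 128, 4);
    struct ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32,  16, 4);

    // out[i1][ids[i1][j]] = 0.0f for each j; every other element keeps a's value
    struct ggml_tensor * out = ggml_scatter(ctx, a, ids, 0.0f);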
@@ -2029,9 +2029,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             {
                 ggml_compute_forward_hadamard(params, tensor);
             } break;
-        case GGML_OP_WHERE_ID:
+        case GGML_OP_SCATTER:
             {
-                ggml_compute_forward_where_id(params, tensor);
+                ggml_compute_forward_scatter(params, tensor);
             } break;
         case GGML_OP_MAP_CUSTOM1:
             {
@@ -2356,7 +2356,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_RWKV_WKV7:
         case GGML_OP_HADAMARD:
-        case GGML_OP_WHERE_ID:
+        case GGML_OP_SCATTER:
             {
                 n_tasks = n_threads;
             } break;
@@ -11257,32 +11257,30 @@ void ggml_compute_forward_hadamard(
     }
 }

-// ggml_compute_forward_where_id
+// ggml_compute_forward_scatter

-static void ggml_compute_forward_where_id_f32(
+static void ggml_compute_forward_scatter_f32(
         const ggml_compute_params * params,
         ggml_tensor * dst) {

     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
-    const ggml_tensor * src2 = dst->src[2];
+    const float c = ggml_get_op_params_f32(dst, 0);

-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));

     GGML_ASSERT(dst->type  == GGML_TYPE_F32);
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);

     GGML_ASSERT(src0->nb[0] == sizeof(float));
-    GGML_ASSERT(src1->nb[0] == sizeof(float));

     const int ith = params->ith;
     const int nth = params->nth;

     const int nr = ggml_nrows(src0);

-    GGML_TENSOR_TERNARY_OP_LOCALS
+    GGML_TENSOR_BINARY_OP_LOCALS

     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
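The lines elided between this hunk and the next are presumably ggml's standard per-thread row split over the `nr` rows set up above; a self-contained sketch of that assumed partitioning (not shown in the commit):

    // assumed shape of the elided code: split nr rows over nth threads,
    // ggml's usual pattern (thread ith then loops over rows [ir0, ir1))
    static void thread_row_range(int nr, int ith, int nth, int * ir0, int * ir1) {
        const int dr = (nr + nth - 1)/nth;        // rows per thread, rounded up
        *ir0 = dr*ith;                            // first row owned by this thread
        *ir1 = *ir0 + dr < nr ? *ir0 + dr : nr;   // one past the last row
    }

The loop body then decomposes the flat row index `ir` into `(i1, i2, i3)` coordinates, which is where the next hunk picks up.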
@@ -11301,23 +11299,22 @@ static void ggml_compute_forward_where_id_f32(
         const int i1 = (ir - i3*ne2*ne1 - i2*ne1);

         const float   * src0_ptr = (float *)   ((char *) src0->data + i3*nb3  + i2*nb2  + i1*nb1 );
-        const float   * src1_ptr = (float *)   ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-        const int32_t * ids_ptr  = (int32_t *) ((char *) src2->data + i3*nb23 + i2*nb22 + i1*nb21);
+        const int32_t * ids_ptr  = (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
         float         * dst_ptr  = (float *)   ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );

-        // copy whole row from src1
-        ggml_vec_cpy_f32(ne00, dst_ptr, src1_ptr);
+        // copy whole row from src0
+        ggml_vec_cpy_f32(ne00, dst_ptr, src0_ptr);

-        // copy only values from src0 indicated by indices in src2
-        for (int j = 0; j < ne20; ++j) {
+        // set dst elements indicated by indices in src1 to c
+        for (int j = 0; j < ne10; ++j) {
             int id = ids_ptr[j];
             GGML_ASSERT(id >= 0 && id < ne00);
-            dst_ptr[id] = src0_ptr[id];
+            dst_ptr[id] = c;
         }
     }
 }

-void ggml_compute_forward_where_id(
+void ggml_compute_forward_scatter(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
@@ -11326,11 +11323,11 @@ void ggml_compute_forward_where_id(
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_where_id_f32(params, dst);
+                ggml_compute_forward_scatter_f32(params, dst);
             } break;
         default:
             {
-                GGML_ABORT("unsupported type for ggml_compute_forward_where_id: %s", ggml_type_name(src0->type));
+                GGML_ABORT("unsupported type for ggml_compute_forward_scatter: %s", ggml_type_name(src0->type));
             }
     }
 }
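Taken together, the CPU path handles each row in two steps, bulk copy then point writes; the same logic as a standalone C sketch (hypothetical buffers and helper name):

    #include <assert.h>
    #include <stdint.h>
    #include <string.h>

    // one row of GGML_OP_SCATTER: dst = src0, then dst[ids[j]] = c for each id
    static void scatter_row_f32(float * dst, const float * src0,
                                const int32_t * ids,
                                int ne00 /* row width */, int n_ids, float c) {
        memcpy(dst, src0, ne00*sizeof(float)); // step 1: copy the whole row
        for (int j = 0; j < n_ids; ++j) {
            const int32_t id = ids[j];
            assert(id >= 0 && id < ne00);      // ids must stay inside the row
            dst[id] = c;                       // step 2: overwrite selected slots
        }
    }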
@@ -104,7 +104,7 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_gated_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_hadamard(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_where_id(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_scatter(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -62,7 +62,7 @@
 #include "ggml-cuda/cumsum.cuh"
 #include "ggml-cuda/fill.cuh"
 #include "ggml-cuda/hadamard.cuh"
-#include "ggml-cuda/where-id.cuh"
+#include "ggml-cuda/scatter.cuh"
 #include "ggml.h"

 #include <algorithm>
@@ -2776,8 +2776,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
         case GGML_OP_HADAMARD:
             ggml_cuda_op_hadamard(ctx, dst);
             break;
-        case GGML_OP_WHERE_ID:
-            ggml_cuda_op_where_id(ctx, dst);
+        case GGML_OP_SCATTER:
+            ggml_cuda_op_scatter(ctx, dst);
             break;
         default:
             return false;
@@ -5020,7 +5020,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
         case GGML_OP_TRI:
         case GGML_OP_DIAG:
         case GGML_OP_SOLVE_TRI:
-        case GGML_OP_WHERE_ID:
+        case GGML_OP_SCATTER:
             return true;
         case GGML_OP_HADAMARD:
             return (op->ne[0] == 64 || op->ne[0] == 128 || op->ne[0] == 256) && op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
@@ -0,0 +1,72 @@
+#include "scatter.cuh"
+
+static __global__ void scatter_kernel(
+        const int32_t * src0, float * dst, const float c,
+        int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
+        size_t nb1, size_t nb2, size_t nb3,
+        size_t nb01, size_t nb02, size_t nb03
+) {
+
+    const int64_t total_blocks = ne01 * ne02 * ne03;
+
+    for (int64_t block_idx = blockIdx.x; block_idx < total_blocks; block_idx += gridDim.x) {
+
+        const int64_t i1 = block_idx % ne01;
+        const int64_t i2 = (block_idx / ne01) % ne02;
+        const int64_t i3 = block_idx / (ne01 * ne02);
+
+        float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2 + i3*nb3);
+        const int * src0_row = (const int *)((const char *)src0 + i1*nb01 + i2*nb02 + i3*nb03);
+
+        for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) {
+            const int32_t id = src0_row[i0];
+            dst_row[id] = c;
+        }
+    }
+}
+
+void ggml_cuda_op_scatter(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(int32_t));
+
+    GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(dst));
+
+    float c;
+    memcpy(&c, (float *) dst->op_params + 0, sizeof(float));
+
+    // step 1 - copy whole src0 to dst
+    cudaStream_t main_stream = ctx.stream();
+    char * dst_ddc  = (char *) dst->data;
+    char * src0_ddc = (char *) src0->data;
+
+    CUDA_CHECK(cudaMemcpyAsync(dst_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
+
+    // step 2 - set elements in dst indicated by ids to c
+    const int32_t * src1_d = (const int32_t *) src1->data;
+    float * dst_d = (float *) dst->data;
+
+    int threads = std::min((int) ne10, 768); // ids
+
+    int64_t total_blocks = ne11 * ne12 * ne13;
+    int blocks = (int) std::min((int64_t) 65535, total_blocks);
+
+    scatter_kernel<<<blocks, threads, 0, ctx.stream()>>>(
+        src1_d, dst_d, c,
+        ne10, ne11, ne12, ne13,
+        nb1, nb2, nb3,
+        nb11, nb12, nb13
+    );
+}
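The launch maps one CUDA block per index row, grid-striding when the row count exceeds the 65535-block cap, while threads within a block stride over the ids of that row. The kernel's flat-index decomposition can be sanity-checked host-side; a small C mirror of it (hypothetical helper name):

    #include <stdint.h>

    // mirror of scatter_kernel's block_idx -> (i1, i2, i3) mapping for an
    // index tensor with ne11 x ne12 x ne13 rows
    static void scatter_row_coords(int64_t block_idx, int64_t ne11, int64_t ne12,
                                   int64_t * i1, int64_t * i2, int64_t * i3) {
        *i1 = block_idx % ne11;            // fastest-varying dimension
        *i2 = (block_idx / ne11) % ne12;   // middle dimension
        *i3 = block_idx / (ne11 * ne12);   // slowest dimension
    }

Note that duplicate ids make several threads store to the same address, but every store writes the same `c`, so the race is benign.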
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_scatter(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -1,78 +0,0 @@
-#include "where-id.cuh"
-
-static __global__ void where_id_kernel(
-        const float * src0, const int32_t * src1, float * dst,
-        int64_t ne10, int64_t ne11, int64_t ne12, int64_t ne13,
-        size_t nb1, size_t nb2, size_t nb3,
-        size_t nb01, size_t nb02, size_t nb03,
-        size_t nb11, size_t nb12, size_t nb13
-) {
-
-    const int64_t total_blocks = ne11 * ne12 * ne13;
-
-    for (int64_t block_idx = blockIdx.x; block_idx < total_blocks; block_idx += gridDim.x) {
-
-        const int64_t i1 = block_idx % ne11;
-        const int64_t i2 = (block_idx / ne11) % ne12;
-        const int64_t i3 = block_idx / (ne11 * ne12);
-
-        float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2 + i3*nb3);
-        const float * src0_row = (const float *)((const char *)src0 + i1*nb01 + i2*nb02 + i3*nb03);
-        const int * src1_row = (const int *)((const char *)src1 + i1*nb11 + i2*nb12 + i3*nb13);
-
-        for (int64_t i0 = threadIdx.x; i0 < ne10; i0 += blockDim.x) {
-            const int32_t id = src1_row[i0];
-            dst_row[id] = src0_row[id];
-        }
-    }
-}
-
-void ggml_cuda_op_where_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-    const ggml_tensor * src2 = dst->src[2];
-
-    GGML_TENSOR_TERNARY_OP_LOCALS
-
-    GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(ggml_is_contiguous(src2));
-
-    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(src2->type == GGML_TYPE_I32);
-
-    GGML_ASSERT(nb00 == sizeof(float));
-    GGML_ASSERT(nb10 == sizeof(float));
-    GGML_ASSERT(nb20 == sizeof(int32_t));
-
-    GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
-    GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(dst));
-
-    // step 1 - copy whole src1 to dst
-    cudaStream_t main_stream = ctx.stream();
-    char * dst_ddc  = (char *) dst->data;
-    char * src1_ddc = (char *) src1->data;
-
-    CUDA_CHECK(cudaMemcpyAsync(dst_ddc, src1_ddc, ggml_nbytes(src1), cudaMemcpyDeviceToDevice, main_stream));
-
-    // step 2 - copy elements from src0 indicated by ids to dst
-    const float   * src0_d = (const float *) src0->data;
-    const int32_t * src2_d = (const int32_t *) src2->data;
-    float * dst_d = (float *) dst->data;
-
-    int threads = std::min((int) ne20, 768); // ids
-
-    int64_t total_blocks = ne21 * ne22 * ne23;
-    int blocks = (int) std::min((int64_t) 65535, total_blocks);
-
-    where_id_kernel<<<blocks, threads, 0, ctx.stream()>>>(
-        src0_d, src2_d, dst_d,
-        ne20, ne21, ne22, ne23,
-        nb1, nb2, nb3,
-        nb01, nb02, nb03,
-        nb21, nb22, nb23
-    );
-}
@@ -1,3 +0,0 @@
-#include "common.cuh"
-
-void ggml_cuda_op_where_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -1033,7 +1033,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "SOLVE_TRI",
     "GATED_DELTA_NET",
     "HADAMARD",
-    "WHERE_ID",
+    "SCATTER",

     "UNARY",
@@ -1145,7 +1145,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "A X = B, A triangular, solve X",
     "gated_delta_net(q, k, v, g, beta, s)",
     "hadamard(x)",
-    "where_id(x,y,ids)",
+    "scatter(x,ids,c)",

     "unary(x)",
@@ -6203,25 +6203,26 @@ struct ggml_tensor * ggml_hadamard(
     return result;
 }

-// ggml_where_id
+// ggml_scatter

-struct ggml_tensor * ggml_where_id(
+struct ggml_tensor * ggml_scatter(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        struct ggml_tensor  * b,
-        struct ggml_tensor  * ids) {
+        struct ggml_tensor  * ids,
+        float                 c) {

     GGML_ASSERT(a->type   == GGML_TYPE_F32);
-    GGML_ASSERT(b->type   == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
     GGML_ASSERT(a->ne[1]  == ids->ne[1]);

     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne);

-    result->op = GGML_OP_WHERE_ID;
+    float params[1] = { c };
+    ggml_set_op_params(result, &params, sizeof(params));
+
+    result->op = GGML_OP_SCATTER;
     result->src[0] = a;
-    result->src[1] = b;
-    result->src[2] = ids;
+    result->src[1] = ids;

     return result;
 }
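The scalar `c` rides along in the node's op_params blob: `ggml_scatter` packs it with `ggml_set_op_params`, the CPU path reads it back with `ggml_get_op_params_f32(dst, 0)`, and the CUDA path with a raw memcpy, all visible in the hunks above. A condensed sketch of that round trip (fragment, assuming a built `result`/`dst` node):

    // store: pack the scalar into the node's op_params bytes (as in ggml_scatter)
    float params[1] = { c };
    ggml_set_op_params(result, &params, sizeof(params));

    // load, CPU path: typed accessor
    const float c_cpu = ggml_get_op_params_f32(dst, 0);

    // load, CUDA path: raw read of the same bytes
    float c_cuda;
    memcpy(&c_cuda, (const float *) dst->op_params + 0, sizeof(float));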
@@ -2177,7 +2177,10 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask_f32, -INFINITY);

         // modify it by unmasking tokens that are in top_k indices
-        ggml_tensor * kq_mask_top_k = ggml_where_id(ctx0, kq_mask_f32, kq_mask_all, top_k);
+        ggml_tensor * kq_mask_top_k = ggml_scatter(ctx0, kq_mask_all, top_k, 0);
+
+        // combine with the original kq mask
+        kq_mask_top_k = ggml_add(ctx0, kq_mask_top_k, kq_mask_f32);
         kq_mask_top_k = ggml_cast(ctx0, kq_mask_top_k, kq_mask->type);

         ggml_tensor * q = q_cur;
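The call-site change also shows why the extra `ggml_add` is needed: scatter writes 0.0f at the top-k positions of an all-`-INFINITY` mask, and adding the original mask restores its values exactly there, while `-INFINITY + x` stays `-INFINITY` everywhere else. The subgraph, condensed with comments (tensor names as in the diff):

    // every position masked out
    ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask_f32, -INFINITY);

    // open the top-k positions: those entries become 0.0f
    ggml_tensor * kq_mask_top_k = ggml_scatter(ctx0, kq_mask_all, top_k, 0);

    // 0 + mask = mask at top-k positions; -INF + mask = -INF elsewhere
    kq_mask_top_k = ggml_add(ctx0, kq_mask_top_k, kq_mask_f32);
    kq_mask_top_k = ggml_cast(ctx0, kq_mask_top_k, kq_mask->type);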