ggml : added f16 version of GGML_OP_SCATTER
This commit is contained in:
parent
54945c7ec1
commit
5677f082b0
|
|
@ -11318,6 +11318,68 @@ static void ggml_compute_forward_scatter_f32(
|
|||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_scatter_f16(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
const ggml_fp16_t c = GGML_CPU_FP32_TO_FP16(ggml_get_op_params_f32(dst, 0));
|
||||
const bool inplace = ggml_get_op_params_i32(dst, 1);
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int nr = ggml_nrows(src0);
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
|
||||
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
// row range for this thread
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
// src0 indices
|
||||
const int i3 = ir/(ne2*ne1);
|
||||
const int i2 = (ir - i3*ne2*ne1)/ne1;
|
||||
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
|
||||
|
||||
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 );
|
||||
const int32_t * ids_ptr = (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
|
||||
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
|
||||
|
||||
// copy whole row from src0
|
||||
if (!inplace) {
|
||||
// ggml_vec_cpy_f16(ne00, dst_ptr, src0_ptr)
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
dst_ptr[i] = src0_ptr[i];
|
||||
}
|
||||
}
|
||||
|
||||
// set dst elements indicated by indices in src1 to c
|
||||
for (int j = 0; j < ne10; ++j) {
|
||||
int id = ids_ptr[j];
|
||||
GGML_ASSERT(id >= 0 && id < ne00);
|
||||
dst_ptr[id] = c;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_scatter(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
|
@ -11329,6 +11391,10 @@ void ggml_compute_forward_scatter(
|
|||
{
|
||||
ggml_compute_forward_scatter_f32(params, dst);
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
ggml_compute_forward_scatter_f16(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("unsupported type for ggml_compute_forward_scatter: %s", ggml_type_name(src0->type));
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
#include "scatter.cuh"
|
||||
#include "convert.cuh"
|
||||
|
||||
template <typename T>
|
||||
static __global__ void scatter_kernel(
|
||||
const int32_t * src0, float * dst, const float c,
|
||||
const int32_t * src0, T * dst, const T c,
|
||||
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
|
||||
size_t nb1, size_t nb2, size_t nb3,
|
||||
size_t nb01, size_t nb02, size_t nb03
|
||||
|
|
@ -15,7 +17,7 @@ static __global__ void scatter_kernel(
|
|||
const int64_t i2 = (block_idx / ne01) % ne02;
|
||||
const int64_t i3 = block_idx / (ne01 * ne02);
|
||||
|
||||
float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2 + i3*nb3);
|
||||
T * dst_row = (T *)((char *)dst + i1*nb1 + i2*nb2 + i3*nb3);
|
||||
const int * src0_row = (const int *)((const char *)src0 + i1*nb01 + i2*nb02 + i3*nb03);
|
||||
|
||||
for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) {
|
||||
|
|
@ -35,11 +37,9 @@ void ggml_cuda_op_scatter(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == src0->type);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
||||
|
||||
GGML_ASSERT(nb00 == sizeof(float));
|
||||
GGML_ASSERT(nb10 == sizeof(int32_t));
|
||||
|
||||
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(dst));
|
||||
|
|
@ -58,17 +58,31 @@ void ggml_cuda_op_scatter(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
|
||||
// step 2 - set elements in dst indicated by ids to c
|
||||
const int32_t * src1_d = (const int32_t *) src1->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
void * dst_d = dst->data;
|
||||
|
||||
int threads = std::min((int) ne10, 512); // ids
|
||||
|
||||
int64_t total_blocks = ne11 * ne12 * ne13;
|
||||
int blocks = (int) std::min((int64_t) 65535, total_blocks);
|
||||
|
||||
scatter_kernel<<<blocks, threads, 0, ctx.stream()>>>(
|
||||
src1_d, dst_d, c,
|
||||
ne10, ne11, ne12, ne13,
|
||||
nb1, nb2, nb3,
|
||||
nb11, nb12, nb13
|
||||
);
|
||||
switch (dst->type) {
|
||||
case GGML_TYPE_F32:
|
||||
scatter_kernel<<<blocks, threads, 0, ctx.stream()>>>(
|
||||
src1_d, (float *) dst_d, c,
|
||||
ne10, ne11, ne12, ne13,
|
||||
nb1, nb2, nb3,
|
||||
nb11, nb12, nb13
|
||||
);
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
scatter_kernel<<<blocks, threads, 0, ctx.stream()>>>(
|
||||
src1_d, (half *) dst_d, ggml_cuda_cast<half>(c),
|
||||
ne10, ne11, ne12, ne13,
|
||||
nb1, nb2, nb3,
|
||||
nb11, nb12, nb13
|
||||
);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("unsupported type");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6212,7 +6212,7 @@ static struct ggml_tensor * ggml_scatter_impl(
|
|||
float c,
|
||||
bool inplace) {
|
||||
|
||||
GGML_ASSERT(a->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(a->ne[1] == ids->ne[1]);
|
||||
|
||||
|
|
|
|||
|
|
@ -8538,6 +8538,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
test_cases.emplace_back(new test_scatter(GGML_TYPE_F32, GGML_TYPE_I32, {10, 1, 1, 1}, {3, 1, 1, 1}, 0.0f, false));
|
||||
test_cases.emplace_back(new test_scatter(GGML_TYPE_F32, GGML_TYPE_I32, {10, 10, 10, 10}, {3, 10, 10, 10}, 0.0f, true));
|
||||
test_cases.emplace_back(new test_scatter(GGML_TYPE_F32, GGML_TYPE_I32, {10, 10, 10, 10}, {3, 10, 10, 10}, 0.0f, false));
|
||||
test_cases.emplace_back(new test_scatter(GGML_TYPE_F16, GGML_TYPE_I32, {10, 1, 1, 1}, {3, 1, 1, 1}, 0.0f, true));
|
||||
test_cases.emplace_back(new test_scatter(GGML_TYPE_F16, GGML_TYPE_I32, {10, 1, 1, 1}, {3, 1, 1, 1}, 0.0f, false));
|
||||
test_cases.emplace_back(new test_scatter(GGML_TYPE_F16, GGML_TYPE_I32, {10, 10, 10, 10}, {3, 10, 10, 10}, 0.0f, true));
|
||||
test_cases.emplace_back(new test_scatter(GGML_TYPE_F16, GGML_TYPE_I32, {10, 10, 10, 10}, {3, 10, 10, 10}, 0.0f, false));
|
||||
|
||||
return test_cases;
|
||||
}
|
||||
|
|
@ -8798,6 +8802,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
|||
// scatter
|
||||
test_cases.emplace_back(new test_scatter(GGML_TYPE_F32, GGML_TYPE_I32, {65536, 1, 1, 1}, {2048, 1, 1, 1}, 0.0f, true));
|
||||
test_cases.emplace_back(new test_scatter(GGML_TYPE_F32, GGML_TYPE_I32, {65536, 1, 1, 1}, {2048, 1, 1, 1}, 0.0f, false));
|
||||
test_cases.emplace_back(new test_scatter(GGML_TYPE_F16, GGML_TYPE_I32, {65536, 1, 1, 1}, {2048, 1, 1, 1}, 0.0f, true));
|
||||
test_cases.emplace_back(new test_scatter(GGML_TYPE_F16, GGML_TYPE_I32, {65536, 1, 1, 1}, {2048, 1, 1, 1}, 0.0f, false));
|
||||
|
||||
return test_cases;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue